diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/tensor_zip.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/tensor_zip.hpp deleted file mode 100644 index 279c4054d414e99829b8bdb0fb6e6875e1a9b9e3..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/tensor_zip.hpp +++ /dev/null @@ -1,246 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include // CUTE_HOST_DEVICE -#include // cute::Tensor -#include // cute::tuple - -namespace cute -{ - -// A tuple of Iterators that can be offset asymmetrically -// Note that this only accepts op+(tuple) and op[tuple] -// where each iterator will be offset by its respective index only. -// READ-ONLY for now until cute::tuple can be constructed with references. -template -struct ZipIterator -{ - using value_type = cute::tuple...>; - using element_type = cute::tuple...>; - // NOTE: cute::tuple does not support constructions with references at the moment. - // Consider fixes and/or an implementation of std::forward_as_tuple. - // For now, use a cute::tuple of value_types instead, which makes this Iterator READ-ONLY. - //using reference = cute::tuple...>; - using reference = value_type; - - ZipIterator() = delete; - - CUTE_HOST_DEVICE constexpr - ZipIterator(Iters... iters) - : iters_(iters...) - {} - - CUTE_HOST_DEVICE constexpr - ZipIterator(cute::tuple const& iters) - : iters_(iters) - {} - - CUTE_HOST_DEVICE constexpr - reference operator*() const { - return cute::apply(iters_, [](auto&&... args) { return reference(*args...); }); - } - - template - CUTE_HOST_DEVICE constexpr - ZipIterator operator+(cute::tuple const& idxs) const { - static_assert(sizeof...(Index) == sizeof...(Iters), "Expect same number of offsets as iterators."); - return cute::transform(iters_, idxs, [](auto&& iter, auto&& idx) { return iter + idx; }); - } - - template - CUTE_HOST_DEVICE constexpr - reference operator[](cute::tuple const& idxs) const { - return *(*this + idxs); - } - - cute::tuple iters_; -}; - -//------------------------------------------------------------------------------ -// type traits - -template -struct is_rmem> : conjunction...> {}; -template -struct is_smem> : conjunction...> {}; -template -struct is_gmem> : conjunction...> {}; -template -struct is_tmem> : conjunction...> {}; - -// A tuple of Layouts that operates on each Layout symmetrically -// The Layouts need to have compatible shapes and ranks. -// The ZipLayout presents the intersection of the domain of its component Layouts. -// E.g. all Layouts accept 1D coords and ZipLayout does as well. -// The ZipLayout returns the union of the codomain of its component Layouts. -// E.g. all Layouts return an integer so ZipLayout returns a tuple of integers. -template -struct ZipLayout -{ - static constexpr int rank = (int(0) | ... | Layouts::rank); - - static_assert((is_layout::value && ...), "All template parameters must be layouts"); - static_assert(((Layouts::rank == rank) && ...), "All layouts must have the same rank"); - - CUTE_HOST_DEVICE constexpr - ZipLayout(Layouts const&... layouts) - : layouts_(layouts...) - {} - - CUTE_HOST_DEVICE constexpr - ZipLayout(cute::tuple const& layouts) - : layouts_(layouts) - {} - - template - CUTE_HOST_DEVICE constexpr - auto - operator()(Coord const& coord) const { - if constexpr (has_underscore::value) { - return ZipLayout(cute::transform(layouts_, [&] (auto layout) { return layout(coord); })); - } else { - return cute::transform(layouts_, [&] (auto layout) { return layout(coord); }); - } - - CUTE_GCC_UNREACHABLE; - } - - // op() convenience function for multi-dimensional coordinates - template - CUTE_HOST_DEVICE constexpr - decltype(auto) - operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { - return operator()(make_coord(c0,c1,cs...)); - } - - cute::tuple layouts_; -}; - -template -struct is_layout> : true_type {}; - -// -// make_zip_tensor and unzip_tensor -// - -template -CUTE_HOST_DEVICE constexpr -auto -make_zip_tensor(Tensor const&... tensors) -{ - return make_tensor(ZipIterator(tensors.data()...), - ZipLayout(tensors.layout()...)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -unzip_tensor(Tensor const& tensor) -{ - return cute::transform(tensor.data().iters_, tensor.layout().layouts_, - [](auto iter, auto layout) { return make_tensor(iter, layout); }); -} - -// -// Utilities -// - -template -CUTE_HOST_DEVICE constexpr -auto -rank(ZipLayout const& layouts) -{ - return rank(get<0>(layouts.layouts_)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -size(ZipLayout const& layouts) -{ - return size(get<0>(layouts.layouts_)); -} - -// -// Manipulation -// - -// Extend each component layout to rank-N by appending Layout @a x. -template -CUTE_HOST_DEVICE constexpr -auto -append(ZipLayout const& layouts, - Layout const& x = {}) -{ - return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return append(t, x); })); -} - -// Extend each component layout to rank-N by prepending Layout @a x. -template -CUTE_HOST_DEVICE constexpr -auto -prepend(ZipLayout const& layouts, - Layout const& x = {}) -{ - return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return prepend(t, x); })); -} - -template -CUTE_HOST_DEVICE constexpr -auto -logical_divide(ZipLayout const& layouts, - Tiler const& tiler) -{ - return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return logical_divide(t, tiler); })); -} - -template -CUTE_HOST_DEVICE constexpr -auto -zipped_divide(ZipLayout const& layouts, - Tiler const& tiler) -{ - return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return zipped_divide(t, tiler); })); -} - -// Return by calling slice_and_offset and all component layouts. -template -CUTE_HOST_DEVICE constexpr -auto -slice_and_offset(Coord const& c, ZipLayout const& layouts) -{ - auto result = cute::zip(cute::transform(layouts.layouts_, [&c](auto const& layout) { return slice_and_offset(c, layout); })); - return cute::make_tuple(ZipLayout(get<0>(result)), get<1>(result)); -} - -} // end namespace cute diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/underscore.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/underscore.hpp deleted file mode 100644 index 8a83b867c9b1d70137ae7146e9e3197c1748e350..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/underscore.hpp +++ /dev/null @@ -1,194 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include // CUTE_INLINE_CONSTANT, CUTE_HOST_DEVICE -#include // cute::is_tuple -#include // cute::false_type, cute::true_type - -namespace cute -{ - -// For slicing -struct Underscore : Int<0> {}; - -CUTE_INLINE_CONSTANT Underscore _; - -// Convenient alias -using X = Underscore; - -// Treat Underscore as an integral like integral_constant -template <> -struct is_integral : true_type {}; - -template -struct is_underscore : false_type {}; -template <> -struct is_underscore : true_type {}; - -// Tuple trait for detecting static member element -template -struct has_elem : false_type {}; -template -struct has_elem : true_type {}; -template -struct has_elem::value> > - : has_elem > {}; -template -struct has_elem> - : disjunction, Elem>...> {}; - -// Tuple trait for detecting static member element -template -struct all_elem : false_type {}; -template -struct all_elem : true_type {}; -template -struct all_elem::value> > - : all_elem > {}; -template -struct all_elem> - : conjunction, Elem>...> {}; - -// Tuple trait for detecting Underscore member -template -using has_underscore = has_elem; - -template -using all_underscore = all_elem; - -template -using has_int1 = has_elem>; - -template -using has_int0 = has_elem>; - -// -// Slice keeps only the elements of Tuple B that are paired with an Underscore -// - -namespace detail { - -template -CUTE_HOST_DEVICE constexpr -auto -lift_slice(A const& a, B const& b) -{ - if constexpr (is_tuple::value) { - static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); - return filter_tuple(a, b, [](auto const& x, auto const& y) { return lift_slice(x,y); }); - } else if constexpr (is_underscore::value) { - return cute::tuple{b}; - } else { - return cute::tuple<>{}; - } - - CUTE_GCC_UNREACHABLE; -} - -} // end namespace detail - -// Entry point overrides the lifting so that slice(_,b) == b -template -CUTE_HOST_DEVICE constexpr -auto -slice(A const& a, B const& b) -{ - if constexpr (is_tuple::value) { - static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); - return filter_tuple(a, b, [](auto const& x, auto const& y) { return detail::lift_slice(x,y); }); - } else if constexpr (is_underscore::value) { - return b; - } else { - return cute::tuple<>{}; - } - - CUTE_GCC_UNREACHABLE; -} - -// -// Dice keeps only the elements of Tuple B that are paired with an Int -// - -namespace detail { - -template -CUTE_HOST_DEVICE constexpr -auto -lift_dice(A const& a, B const& b) -{ - if constexpr (is_tuple::value) { - static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); - return filter_tuple(a, b, [](auto const& x, auto const& y) { return lift_dice(x,y); }); - } else if constexpr (is_underscore::value) { - return cute::tuple<>{}; - } else { - return cute::tuple{b}; - } - - CUTE_GCC_UNREACHABLE; -} - -} // end namespace detail - -// Entry point overrides the lifting so that dice(1,b) == b -template -CUTE_HOST_DEVICE constexpr -auto -dice(A const& a, B const& b) -{ - if constexpr (is_tuple::value) { - static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); - return filter_tuple(a, b, [](auto const& x, auto const& y) { return detail::lift_dice(x,y); }); - } else if constexpr (is_underscore::value) { - return cute::tuple<>{}; - } else { - return b; - } - - CUTE_GCC_UNREACHABLE; -} - -// -// Display utilities -// - -CUTE_HOST_DEVICE void print(Underscore const&) { - printf("_"); -} - -#if !defined(__CUDACC_RTC__) -CUTE_HOST std::ostream& operator<<(std::ostream& os, Underscore const&) { - return os << "_"; -} -#endif - -} // end namespace cute diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/debug.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/debug.hpp deleted file mode 100644 index 5e704b2599858e15590cd545cc5c28a9e289f935..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/debug.hpp +++ /dev/null @@ -1,164 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -/** - * \file - * \brief Debugging and logging functionality - */ - -#include - -#include - -namespace cute -{ - -/****************************************************************************** - * Debug and logging macros - ******************************************************************************/ - -/** - * Formats and prints the given message to stdout - */ -#if !defined(CUTE_LOG) -# if !defined(__CUDA_ARCH__) -# define CUTE_LOG(format, ...) printf(format, __VA_ARGS__) -# else -# define CUTE_LOG(format, ...) \ - printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ - blockIdx.x, blockIdx.y, blockIdx.z, \ - threadIdx.x, threadIdx.y, threadIdx.z, \ - __VA_ARGS__); -# endif -#endif - -/** - * Formats and prints the given message to stdout only if DEBUG is defined - */ -#if !defined(CUTE_LOG_DEBUG) -# ifdef DEBUG -# define CUTE_LOG_DEBUG(format, ...) CUTE_LOG(format, __VA_ARGS__) -# else -# define CUTE_LOG_DEBUG(format, ...) -# endif -#endif - -/** - * \brief Perror macro with exit - */ -#if !defined(CUTE_ERROR_EXIT) -# define CUTE_ERROR_EXIT(e) \ - do { \ - cudaError_t code = (e); \ - if (code != cudaSuccess) { \ - fprintf(stderr, "<%s:%d> %s:\n %s: %s\n", \ - __FILE__, __LINE__, #e, \ - cudaGetErrorName(code), cudaGetErrorString(code)); \ - fflush(stderr); \ - exit(1); \ - } \ - } while (0) -#endif - -#if !defined(CUTE_CHECK_LAST) -# define CUTE_CHECK_LAST() CUTE_ERROR_EXIT(cudaPeekAtLastError()); CUTE_ERROR_EXIT(cudaDeviceSynchronize()) -#endif - -#if !defined(CUTE_CHECK_ERROR) -# define CUTE_CHECK_ERROR(e) CUTE_ERROR_EXIT(e) -#endif - -// A dummy function that uses compilation failure to print a type -template -CUTE_HOST_DEVICE void -print_type() { - static_assert(sizeof...(T) < 0, "Printing type T."); -} - -template -CUTE_HOST_DEVICE void -print_type(T&&...) { - static_assert(sizeof...(T) < 0, "Printing type T."); -} - -// -// Device-specific helpers -// -// e.g. -// if (thread0()) print(...); -// if (block0()) print(...); -// if (thread(42)) print(...); - -CUTE_HOST_DEVICE -bool -block([[maybe_unused]] int bid) -{ -#if defined(__CUDA_ARCH__) - return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == static_cast(bid); -#else - return true; -#endif -} - -CUTE_HOST_DEVICE -bool -thread([[maybe_unused]] int tid, [[maybe_unused]] int bid) -{ -#if defined(__CUDA_ARCH__) - return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == static_cast(tid)) && block(bid); -#else - return true; -#endif -} - -CUTE_HOST_DEVICE -bool -thread(int tid) -{ - return thread(tid,0); -} - -CUTE_HOST_DEVICE -bool -thread0() -{ - return thread(0,0); -} - -CUTE_HOST_DEVICE -bool -block0() -{ - return block(0); -} - -} // end namespace cute diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print.hpp deleted file mode 100644 index e6cc887adc5a77690645b63e45cbf88a8a99e105..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print.hpp +++ /dev/null @@ -1,266 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include // CUTE_HOST_DEVICE -#include // cute::is_valid -#include - -// -// CUDA compatible print and printf -// - -namespace cute -{ - -CUTE_HOST_DEVICE -int -num_digits(int x) -{ - return (x < 10 ? 1 : - (x < 100 ? 2 : - (x < 1000 ? 3 : - (x < 10000 ? 4 : - (x < 100000 ? 5 : - (x < 1000000 ? 6 : - (x < 10000000 ? 7 : - (x < 100000000 ? 8 : - (x < 1000000000 ? 9 : - 10))))))))); -} - -// -// print dispatcher -// - -CUTE_HOST_DEVICE -void -print(char c) { - printf("%c", c); -} - -CUTE_HOST_DEVICE -void -print(signed char a) { - printf("%d", static_cast(a)); -} - -CUTE_HOST_DEVICE -void -print(unsigned char a) { - printf("%u", static_cast(a)); -} - -CUTE_HOST_DEVICE -void -print(short a) { - printf("%hd", a); -} - -CUTE_HOST_DEVICE -void -print(unsigned short a) { - printf("%hu", a); -} - -CUTE_HOST_DEVICE -void -print(int a) { - printf("%d", a); -} - -CUTE_HOST_DEVICE -void -print(uint1b_t a) { - printf("%d", int(a)); -} - -CUTE_HOST_DEVICE -void -print(int2b_t a) { - printf("%d", int(a)); -} - -CUTE_HOST_DEVICE -void -print(uint2b_t a) { - printf("%d", int(a)); -} - -CUTE_HOST_DEVICE -void -print(int4b_t a) { - printf("%d", int(a)); -} - -CUTE_HOST_DEVICE -void -print(uint4b_t a) { - printf("%d", int(a)); -} - -CUTE_HOST_DEVICE -void -print(bin1_t a) { - printf("%d", int(a)); -} - -CUTE_HOST_DEVICE -void -print(unsigned int a) { - printf("%u", a); -} - -CUTE_HOST_DEVICE -void -print(long a) { - printf("%ld", a); -} - -CUTE_HOST_DEVICE -void -print(unsigned long a) { - printf("%lu", a); -} - -CUTE_HOST_DEVICE -void -print(long long a) { - printf("%lld", a); -} - -CUTE_HOST_DEVICE -void -print(unsigned long long a) { - printf("%llu", a); -} - -CUTE_HOST_DEVICE -void -print(float a) { - printf("%f", a); -} - -CUTE_HOST_DEVICE -void -print(double a) { - printf("%f", a); -} - -template -CUTE_HOST_DEVICE -void -print(char const* format, T const&... t) { - printf(format, t...); -} - -CUTE_HOST_DEVICE -void -print(char const* format) { - printf("%s", format); -} - -// -// pretty printing -// - -CUTE_HOST_DEVICE void -pretty_print(uint1b_t a) { - printf("%*d", 3, int(a)); -} - -CUTE_HOST_DEVICE void -pretty_print(int2b_t a) { - printf("%*d", 5, int(a)); -} - -CUTE_HOST_DEVICE void -pretty_print(uint2b_t a) { - printf("%*d", 5, int(a)); -} - -CUTE_HOST_DEVICE void -pretty_print(int4b_t a) { - printf("%*d", 5, int(a)); -} - -CUTE_HOST_DEVICE void -pretty_print(uint4b_t a) { - printf("%*d", 5, int(a)); -} - -CUTE_HOST_DEVICE void -pretty_print(bool v) { - printf("%*d", 3, int(v)); -} - -CUTE_HOST_DEVICE void -pretty_print(int32_t v) { - printf("%*d", 5, v); -} - -CUTE_HOST_DEVICE void -pretty_print(uint32_t v) { - printf("%*d", 5, v); -} - -CUTE_HOST_DEVICE void -pretty_print(int64_t v) { - printf("%*lld", 5, static_cast(v)); -} - -CUTE_HOST_DEVICE void -pretty_print(uint64_t v) { - printf("%*llu", 5, static_cast(v)); -} - -CUTE_HOST_DEVICE void -pretty_print(float v) { - printf("%*.2e", 10, v); -} - -CUTE_HOST_DEVICE void -pretty_print(double v) { - printf("%*.3e", 11, v); -} - -template -CUTE_HOST_DEVICE void -pretty_print(T t) { - constexpr auto has_print_exmy_base = cute::is_valid([](auto t) -> decltype(pretty_print_float_exmy_base(t)) {}, t); - if constexpr (has_print_exmy_base) { - pretty_print_float_exmy_base(t); - } else { - printf(" "); print(t); - } -} - -} // end namespace cute diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print_latex.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print_latex.hpp deleted file mode 100644 index 28e30ed56fbb6cb7b334447ab89db3b4a1428684..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print_latex.hpp +++ /dev/null @@ -1,438 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include // CUTE_HOST_DEVICE - -#include -#include - -#include -#include - -namespace cute -{ - -/////////////////////////////////////// -// Common LaTeX TikZ Color utilities // -/////////////////////////////////////// - -struct TikzColor_White { - CUTE_HOST_DEVICE char const* - operator()(int idx) const { - return "white"; - } -}; - -struct TikzColor_BWx8 { - CUTE_HOST_DEVICE char const* - operator()(int idx) const { - static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60", - "black!10", "black!50", "black!30", "black!70"}; - return color_map[idx % 8]; - } -}; - -struct TikzColor_TV { - CUTE_HOST_DEVICE char const* - operator()(int tid, int vid) const { - static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", - "{rgb,255:red,175;green,255;blue,175}", - "{rgb,255:red,255;green,255;blue,175}", - "{rgb,255:red,255;green,175;blue,175}", - "{rgb,255:red,210;green,210;blue,255}", - "{rgb,255:red,210;green,255;blue,210}", - "{rgb,255:red,255;green,255;blue,210}", - "{rgb,255:red,255;green,210;blue,210}"}; - return color_map[tid % 8]; - } -}; - -///////////////////////////// -// Layout 2D to LaTeX TikZ // -///////////////////////////// - -template -CUTE_HOST_DEVICE -void -print_latex(LayoutA const& layout_a, // (m,n) -> idx - TikzColorFn color = {}) // lambda(idx) -> tikz color string -{ - CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{}); - auto layout = append<2>(layout_a, Layout<_1,_0>{}); - - // Commented print(layout) - printf("%% Layout: "); print(layout); printf("\n"); - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - auto [M, N] = product_each(shape(layout)); - - // Layout - for (int m = 0; m < M; ++m) { - for (int n = 0; n < N; ++n) { - int idx = layout(m,n); - printf("\\node[fill=%s] at (%d,%d) {%d};\n", - color(idx), m, n, idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", - int(M), int(N)); - // Labels - for (int m = 0, n = -1; m < M; ++m) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m); - } - for (int m = -1, n = 0; n < N; ++n) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n); - } - - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - -template -CUTE_HOST_DEVICE -void -print_latex(ComposedLayout,Layout> const& layout, - TikzColorFn color = {}) // lambda(idx) -> tikz color string) -{ - print_latex(as_position_independent_swizzle_layout(layout), color); -} - -/////////////////////////////// -// LayoutTV 2D to LaTeX TikZ // -/////////////////////////////// - -template -CUTE_HOST_DEVICE -void -print_latex_tv(LayoutTV const& layout_tv, // (t,v) -> m,n coord - Tile_MN const& tile_mn, // (M,N) - TikzColorFn color = {}) // (t,v) -> color -{ - CUTE_STATIC_ASSERT_V(rank(layout_tv) == Int<2>{}); - - // Commented prints - printf("%% Layout TV: "); print(layout_tv); printf("\n"); - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - auto [M, N] = product_each(shape(tile_mn)); - Tensor filled = make_tensor(make_shape(M, N)); - clear(filled); - - // Layout - for (int tid = 0; tid < size<0>(layout_tv); ++tid) { - for (int vid = 0; vid < size<1>(layout_tv); ++vid) { - auto [m, n] = layout_tv(tid, vid); - if (not filled(m, n)) { - filled(m, n) = true; - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(tid, vid), - int(m), int(n), - tid, vid); - } - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", int(M), int(N)); - // Labels - for (int m = 0, n = -1; m < M; ++m) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m); - } - for (int n = 0, m = -1; n < N; ++n) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n); - } - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - -//////////////////////////// -// MMA Atom to LaTeX TikZ // -//////////////////////////// - -namespace detail { - -template -CUTE_HOST_DEVICE -void -print_latex_mma(LayoutC const& C, // (tid,vid) -> (m,n) coord - LayoutA const& A, // (tid,vid) -> (m,k) coord - LayoutB const& B, // (tid,vid) -> (n,k) coord - Tile_MNK const& tile_mnk, // (M,N,K) - TikzColorFn color = {}) // lambda(tid,vid) -> tikz color string -{ - CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); - - // Commented prints - printf("%% LayoutC: "); print(C); printf("\n"); - printf("%% LayoutA: "); print(A); printf("\n"); - printf("%% LayoutB: "); print(B); printf("\n"); - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - auto [M, N, K] = product_each(shape(tile_mnk)); - Tensor filled = make_tensor(make_shape(M, N, K)); - clear(filled); - - // C starting at 0,0 - for (int tid = 0; tid < size<0>(C); ++tid) { - for (int vid = 0; vid < size<1>(C); ++vid) { - auto [m, n] = C(tid, vid); - if (not filled(m, n, 0)) { - filled(m, n, 0) = true; - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(tid, vid), - int(m), int(n), - tid, vid); - } - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, 0, int(M), int(N)); - - clear(filled); - - // A starting at 0,-K-1 - for (int tid = 0; tid < size<0>(A); ++tid) { - for (int vid = 0; vid < size<1>(A); ++vid) { - auto [m, k] = A(tid, vid); - if (not filled(m, 0, k)) { - filled(m, 0, k) = true; - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(tid, vid), - int(m), int(k-K-1), - tid, vid); - } - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, -int(K)-1, int(M), -1); - // A labels - for (int m = 0, k = -1; m < M; ++m) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), m); - } - for (int m = -1, k = 0; k < K; ++k) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), k); - } - - clear(filled); - - // B starting at -K-1,0 - for (int tid = 0; tid < size<0>(B); ++tid) { - for (int vid = 0; vid < size<1>(B); ++vid) { - auto [n, k] = B(tid, vid); - if (not filled(0, n, k)) { - filled(0, n, k) = true; - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(tid, vid), - int(k)-int(K)-1, int(n), - tid, vid); - } - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - -int(K)-1, 0, -1, int(N)); - // B labels - for (int n = 0, k = -1; n < N; ++n) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, n); - } - for (int n = -1, k = 0; k < K; ++k) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, k); - } - - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - -} // end namespace detail - -// MMA Atom to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex(MMA_Atom const& mma_atom, - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - print_latex(make_tiled_mma(mma_atom)); -} - -// TiledMMA to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex(TiledMMA const& mma, - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - auto tile_mnk = tile_shape(mma); - - Tensor refC = make_identity_tensor(select<0,1>(tile_mnk)); - Tensor tensorC_TV = composition(refC, mma.get_layoutC_TV()); - - Tensor refA = make_identity_tensor(select<0,2>(tile_mnk)); - Tensor tensorA_TV = composition(refA, mma.get_layoutA_TV()); - - Tensor refB = make_identity_tensor(select<1,2>(tile_mnk)); - Tensor tensorB_TV = composition(refB, mma.get_layoutB_TV()); - - detail::print_latex_mma(tensorC_TV, tensorA_TV, tensorB_TV, tile_mnk, color); -} - -//////////////////////////// -// CopyAtom to LaTeX TikZ // -//////////////////////////// - -namespace detail { - -// Generic TV Layout to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex_copy(LayoutS_TV const& S, // (t,v) -> m,n coord - LayoutD_TV const& D, // (t,v) -> m,n coord - Tile_MN const& tile_mn, // (M,N) - TikzColorFn color = {}) // (t,v) -> color -{ - CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); - - // Commented prints - printf("%% Layout S TV: "); print(S); printf("\n"); - printf("%% Layout D TV: "); print(D); printf("\n"); - - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - auto [M, N] = product_each(shape(tile_mn)); - Tensor filled = make_tensor(make_shape(M, N)); - clear(filled); - - // S starting at 0,0 - for (int tid = 0; tid < size<0>(S); ++tid) { - for (int vid = 0; vid < size<1>(S); ++vid) { - auto [m, n] = S(tid, vid); - if (not filled(m, n)) { - filled(m, n) = true; - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(tid, vid), - int(m), int(n), - tid, vid); - } - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, 0, int(M), int(N)); - // S Labels - for (int m = 0, n = -1; m < M; ++m) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m); - } - for (int m = -1, n = 0; n < N; ++n) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n); - } - - clear(filled); - - // D starting at 0,N+3 - for (int tid = 0; tid < size<0>(D); ++tid) { - for (int vid = 0; vid < size<1>(D); ++vid) { - auto [m, n] = D(tid, vid); - if (not filled(m, n)) { - filled(m, n) = true; - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(tid, vid), - int(m), int(n) + int(N) + 3, - tid, vid); - } - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, int(N) + 3, int(M), int(N) + int(N) + 3); - // D Labels - for (int m = 0, n = N; m < M; ++m) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(n+N+3), m); - } - for (int m = -1, n = 0; n < N; ++n) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(n+N+3), n); - } - - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - -} // end namespace detail - -// TiledCopy to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex(TiledCopy const& copy, - TikzColorFn color = {}) // lambda(tid,vid) -> tikz color string -{ - auto tiler_mn = typename TiledCopy::Tiler_MN{}; - auto tile_mn = product_each(shape(logical_divide(make_layout(Shape<_1,_1>{}), tiler_mn))); // tile_shape - - Tensor refS = make_identity_tensor(tile_mn); - Tensor layoutS_TV = copy.tidfrg_S(refS)(_,_,Int<0>{}); - - Tensor refD = make_identity_tensor(tile_mn); - Tensor layoutD_TV = copy.tidfrg_D(refD)(_,_,Int<0>{}); - - detail::print_latex_copy(layoutS_TV, layoutD_TV, tile_mn, color); -} - -} // end namespace cute diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print_svg.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print_svg.hpp deleted file mode 100644 index 5d26809ea69a878926585f435de68bb7751e41fd..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print_svg.hpp +++ /dev/null @@ -1,257 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include // CUTE_HOST_DEVICE - -#include -#include - -#include -#include - -namespace cute -{ - -//////////////////////////////// -// Common SVG Color utilities // -//////////////////////////////// - -struct TSVGColor_White { - CUTE_HOST_DEVICE char const* - operator()(int idx) const { - return "255,255,255"; - } -}; - -struct TSVGColor_BWx8 { - CUTE_HOST_DEVICE char const* - operator()(int idx) const { - static char const* color_map[8] = {"255,255,255", "230,230,230", "205,205,205", "180,180,180", - "155,155,155", "130,130,130", "105,105,105", "080,080,080"}; - return color_map[idx % 8]; - } -}; - -struct SVGColor_TV { - CUTE_HOST_DEVICE char const* - operator()(int tid, int vid) const { - static char const* color_map[8] = {"175,175,255", "175,255,175", "255,255,175", "255,175,175", - "210,210,255", "210,255,210", "255,255,210", "255,210,210"}; - return color_map[tid % 8]; - } -}; - -///////////////////// -// MMA Atom to SVG // -///////////////////// - -namespace detail { - -template -CUTE_HOST_DEVICE -void -print_svg_mma(LayoutC const& C, - LayoutA const& A, - LayoutB const& B, - Tile_MNK const& tile_mnk, - SVGColorFn color = {}) // lambda(tid,vid) -> SVG color string -{ - CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); - - auto [M, N, K] = product_each(shape(tile_mnk)); - - int cell_size = 20; - - int page_width = (K + N + 2) * cell_size; - int page_height = (K + M + 2) * cell_size; - - // Commented print - printf("\n"); - printf("\n"); - printf("\n"); - printf("\n"); - - // SVG Header - printf("\n", - page_width, page_height); - - Tensor filled = make_tensor(make_shape(M, N, K)); - clear(filled); - - // --- Draw C --- - for (int tid = 0; tid < size<0>(C); ++tid) { - for (int vid = 0; vid < size<1>(C); ++vid) { - auto [m, n] = C(tid, vid); - if (!filled(m, n, 0)) { - filled(m, n, 0) = true; - - int x = (n + K + 2) * cell_size; - int y = (m + K + 2) * cell_size; - - printf("\n", - x, y, cell_size, cell_size, color(tid,vid)); - printf("T%d\n", - x + cell_size/2, y + 1*cell_size/4, tid); - printf("V%d\n", - x + cell_size/2, y + 3*cell_size/4, vid); - } - } - } - - clear(filled); - - // --- Draw A --- - for (int tid = 0; tid < size<0>(A); ++tid) { - for (int vid = 0; vid < size<1>(A); ++vid) { - auto [m, k] = A(tid, vid); - if (!filled(m, 0, k)) { - filled(m, 0, k) = true; - - int x = (k + 1) * cell_size; - int y = (m + K + 2) * cell_size; - - printf("\n", - x, y, cell_size, cell_size, color(tid,vid)); - printf("T%d\n", - x + cell_size/2, y + 1*cell_size/4, tid); - printf("V%d\n", - x + cell_size/2, y + 3*cell_size/4, vid); - } - } - } - - // A labels - for (int m = 0, k = -1; m < M; ++m) { - int x = (k + 1) * cell_size; - int y = (m + K + 2) * cell_size; - printf("%d\n", - x + cell_size/2, y + cell_size/2, m); - } - for (int m = -1, k = 0; k < K; ++k) { - int x = (k + 1) * cell_size; - int y = (m + K + 2) * cell_size; - printf("%d\n", - x + cell_size/2, y + cell_size/2, k); - } - - clear(filled); - - // --- Draw B --- - for (int tid = 0; tid < size<0>(B); ++tid) { - for (int vid = 0; vid < size<1>(B); ++vid) { - auto [n, k] = B(tid, vid); - if (!filled(0, n, k)) { - filled(0, n, k) = true; - - int x = (n + K + 2) * cell_size; - int y = (k + 1) * cell_size; - - printf("\n", - x, y, cell_size, cell_size, color(tid,vid)); - printf("T%d\n", - x + cell_size/2, y + 1*cell_size/4, tid); - printf("V%d\n", - x + cell_size/2, y + 3*cell_size/4, vid); - } - } - } - - // B labels - for (int n = 0, k = -1; n < N; ++n) { - int x = (n + K + 2) * cell_size; - int y = (k + 1) * cell_size; - printf("%d\n", - x + cell_size/2, y + cell_size/2, n); - } - for (int n = -1, k = 0; k < K; ++k) { - int x = (n + K + 2) * cell_size; - int y = (k + 1) * cell_size; - printf("%d\n", - x + cell_size/2, y + cell_size/2, k); - } - - // SVG footer - printf("\n"); -} - -} // end namespace detail - -// MMA Atom to SVG -template -CUTE_HOST_DEVICE -void -print_svg(MMA_Atom const& mma_atom, - SVGColorFn color = {}) // lambda(thr_idx,val_idx) -> svg color string -{ - print_svg(make_tiled_mma(mma_atom)); -} - -// TiledMMA to SVG -template -CUTE_HOST_DEVICE -void -print_svg(TiledMMA const& mma, - SVGColorFn color = {}) // lambda(thr_idx,val_idx) -> svg color string -{ - auto tile_mnk = tile_shape(mma); - - Tensor refC = make_identity_tensor(select<0,1>(tile_mnk)); - Tensor tensorC_TV = composition(refC, mma.get_layoutC_TV()); - - Tensor refA = make_identity_tensor(select<0,2>(tile_mnk)); - Tensor tensorA_TV = composition(refA, mma.get_layoutA_TV()); - - Tensor refB = make_identity_tensor(select<1,2>(tile_mnk)); - Tensor tensorB_TV = composition(refB, mma.get_layoutB_TV()); - - detail::print_svg_mma(tensorC_TV, tensorA_TV, tensorB_TV, tile_mnk, color); -} - -} // end namespace cute diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print_tensor.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print_tensor.hpp deleted file mode 100644 index c5eb39a1d59ff31389642cf86ab0e4961508170b..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/print_tensor.hpp +++ /dev/null @@ -1,188 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include // CUTE_HOST_DEVICE - -#include -#include - -namespace cute -{ - -//////////////////////////////// -// Layout 2D to Console table // -//////////////////////////////// - -template -CUTE_HOST_DEVICE -void -print_layout(Layout const& layout) // (m,n) -> idx -{ - CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); - - int idx_width = num_digits(cosize(layout)) + 2; - const char* delim = "+-----------------------"; - - print(layout); print("\n"); - - // Column indices - print(" "); - for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); } - printf("\n"); - - // Print out A m-by-n - for (int m = 0; m < size<0>(layout); ++m) { - // Header - print(" "); - for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } - printf("+\n"); - // Values - printf("%2d ", m); // Row indices - for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); } - printf("|\n"); - } - // Footer - print(" "); - for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } - printf("+\n"); -} - -// Capture and cast smem_ptr_flag Layouts to offset-0 layouts -template -CUTE_HOST_DEVICE -void -print_layout(ComposedLayout,Layout> const& layout) -{ - print_layout(as_position_independent_swizzle_layout(layout)); -} - -//////////////////////////////// -// Tensor 1D,2D,3D,4D Console // -//////////////////////////////// - -template -CUTE_HOST_DEVICE -void -print_tensor(Tensor const& tensor, bool print_type = true) -{ - if (print_type) { - print(tensor); print(":\n"); - } - - if constexpr (Layout::rank == 1) - { - for (int m = 0; m < size(tensor); ++m) { - pretty_print(tensor(m)); - printf("\n"); - } - } else - if constexpr (Layout::rank == 2) - { - for (int m = 0; m < size<0>(tensor); ++m) { - for (int n = 0; n < size<1>(tensor); ++n) { - pretty_print(tensor(m,n)); - } - printf("\n"); - } - } else - if constexpr (Layout::rank == 3) - { - print_tensor(tensor(_,_,0), false); - for (int k = 1; k < size<2>(tensor); ++k) { - for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n"); - print_tensor(tensor(_,_,k), false); - } - } else - if constexpr (Layout::rank == 4) - { - print_tensor(tensor(_,_,_,0), false); - for (int p = 1; p < size<3>(tensor); ++p) { - for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n"); - print_tensor(tensor(_,_,_,p), false); - } - } -} - -#if !defined(__CUDACC_RTC__) -template -CUTE_HOST -std::ostream& -print_tensor_os(std::ostream& os, Tensor const& tensor) -{ - int digits = 9; - - if constexpr (Layout::rank == 1) - { - for (int m = 0; m < size(tensor); ++m) { - os << std::setw(digits) << tensor(m) << std::endl; - } - } else - if constexpr (Layout::rank == 2) - { - for (int m = 0; m < size<0>(tensor); ++m) { - for (int n = 0; n < size<1>(tensor); ++n) { - os << std::setw(digits) << tensor(m,n); - } - os << std::endl; - } - } else - if constexpr (Layout::rank == 3) - { - print_tensor_os(os, tensor(_,_,0)); - for (int k = 1; k < size<2>(tensor); ++k) { - for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; - print_tensor_os(os, tensor(_,_,k)); - } - } else - if constexpr (Layout::rank == 4) - { - print_tensor_os(os, tensor(_,_,_,0)); - for (int p = 1; p < size<3>(tensor); ++p) { - for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; - print_tensor_os(os, tensor(_,_,_,p)); - } - } - - return os; -} - -template -CUTE_HOST -std::ostream& -operator<<(std::ostream& os, Tensor const& tensor) -{ - os << tensor.layout() << std::endl; - return print_tensor_os(os, tensor); -} -#endif // !defined(__CUDACC_RTC__) - -} // end namespace cute diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/type_traits.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/type_traits.hpp deleted file mode 100644 index 685babff7cdb2310debe2c713cb40a7fb8c130ea..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/util/type_traits.hpp +++ /dev/null @@ -1,322 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once -#include "cutlass/cutlass.h" -#if defined(__CUDACC_RTC__) -#include CUDA_STD_HEADER(type_traits) -#include CUDA_STD_HEADER(utility) -#include CUDA_STD_HEADER(cstddef) -#include CUDA_STD_HEADER(cstdint) -#include CUDA_STD_HEADER(limits) -#else -#include -#include // tuple_size, tuple_element -#include // ptrdiff_t -#include // uintptr_t -#include // numeric_limits -#endif - -#include // CUTE_STL_NAMESPACE - -namespace cute -{ - using CUTE_STL_NAMESPACE::enable_if; - using CUTE_STL_NAMESPACE::enable_if_t; -} - -#define __CUTE_REQUIRES(...) typename cute::enable_if<(__VA_ARGS__)>::type* = nullptr -#define __CUTE_REQUIRES_V(...) typename cute::enable_if::type* = nullptr - -namespace cute -{ - -// -using CUTE_STL_NAMESPACE::conjunction; -using CUTE_STL_NAMESPACE::conjunction_v; - -using CUTE_STL_NAMESPACE::disjunction; -using CUTE_STL_NAMESPACE::disjunction_v; - -using CUTE_STL_NAMESPACE::negation; -using CUTE_STL_NAMESPACE::negation_v; - -using CUTE_STL_NAMESPACE::void_t; -using CUTE_STL_NAMESPACE::is_void_v; - -using CUTE_STL_NAMESPACE::is_base_of; -using CUTE_STL_NAMESPACE::is_base_of_v; - -using CUTE_STL_NAMESPACE::is_const; -using CUTE_STL_NAMESPACE::is_const_v; -using CUTE_STL_NAMESPACE::is_volatile; -using CUTE_STL_NAMESPACE::is_volatile_v; - -// Defined in cute/numeric/integral_constant.hpp -// using CUTE_STL_NAMESPACE::true_type; -// using CUTE_STL_NAMESPACE::false_type; - -using CUTE_STL_NAMESPACE::conditional; -using CUTE_STL_NAMESPACE::conditional_t; - -using CUTE_STL_NAMESPACE::add_const_t; - -using CUTE_STL_NAMESPACE::remove_const_t; -using CUTE_STL_NAMESPACE::remove_cv_t; -using CUTE_STL_NAMESPACE::remove_reference_t; - -template -struct copy_cv { - using type = Dst; -}; - -template -struct copy_cv { - using type = Dst const; -}; - -template -struct copy_cv { - using type = Dst volatile; -}; - -template -struct copy_cv { - using type = Dst const volatile; -}; - -template -using copy_cv_t = typename copy_cv::type; - -using CUTE_STL_NAMESPACE::extent; -using CUTE_STL_NAMESPACE::remove_extent; - -using CUTE_STL_NAMESPACE::decay; -using CUTE_STL_NAMESPACE::decay_t; - -using CUTE_STL_NAMESPACE::is_lvalue_reference; -using CUTE_STL_NAMESPACE::is_lvalue_reference_v; - -using CUTE_STL_NAMESPACE::is_reference; -using CUTE_STL_NAMESPACE::is_trivially_copyable; - -using CUTE_STL_NAMESPACE::is_convertible; -using CUTE_STL_NAMESPACE::is_convertible_v; - -using CUTE_STL_NAMESPACE::is_same; -using CUTE_STL_NAMESPACE::is_same_v; - -using CUTE_STL_NAMESPACE::is_constructible; -using CUTE_STL_NAMESPACE::is_constructible_v; -using CUTE_STL_NAMESPACE::is_default_constructible; -using CUTE_STL_NAMESPACE::is_default_constructible_v; -using CUTE_STL_NAMESPACE::is_standard_layout; -using CUTE_STL_NAMESPACE::is_standard_layout_v; - -using CUTE_STL_NAMESPACE::is_arithmetic; -using CUTE_STL_NAMESPACE::is_unsigned; -using CUTE_STL_NAMESPACE::is_unsigned_v; -using CUTE_STL_NAMESPACE::is_signed; -using CUTE_STL_NAMESPACE::is_signed_v; - -using CUTE_STL_NAMESPACE::make_signed; -using CUTE_STL_NAMESPACE::make_signed_t; - -// using CUTE_STL_NAMESPACE::is_integral; -template -using is_std_integral = CUTE_STL_NAMESPACE::is_integral; - -using CUTE_STL_NAMESPACE::is_empty; -using CUTE_STL_NAMESPACE::is_empty_v; - -using CUTE_STL_NAMESPACE::invoke_result_t; - -using CUTE_STL_NAMESPACE::common_type; -using CUTE_STL_NAMESPACE::common_type_t; - -using CUTE_STL_NAMESPACE::remove_pointer; -using CUTE_STL_NAMESPACE::remove_pointer_t; - -using CUTE_STL_NAMESPACE::add_pointer; -using CUTE_STL_NAMESPACE::add_pointer_t; - -using CUTE_STL_NAMESPACE::alignment_of; -using CUTE_STL_NAMESPACE::alignment_of_v; - -using CUTE_STL_NAMESPACE::is_pointer; -using CUTE_STL_NAMESPACE::is_pointer_v; - -// -using CUTE_STL_NAMESPACE::declval; - -template -CUTE_HOST_DEVICE constexpr -T&& forward(remove_reference_t& t) noexcept -{ - return static_cast(t); -} - -template -CUTE_HOST_DEVICE constexpr -T&& forward(remove_reference_t&& t) noexcept -{ - static_assert(! is_lvalue_reference_v, "T cannot be an lvalue reference (e.g., U&)."); - return static_cast(t); -} - -template -CUTE_HOST_DEVICE constexpr -remove_reference_t&& move(T&& t) noexcept -{ - return static_cast&&>(t); -} - -// -using CUTE_STL_NAMESPACE::numeric_limits; - -// -using CUTE_STL_NAMESPACE::ptrdiff_t; - -// -using CUTE_STL_NAMESPACE::uintptr_t; - -// C++20 -// using std::remove_cvref; -template -struct remove_cvref { - using type = remove_cv_t>; -}; - -// C++20 -// using std::remove_cvref_t; -template -using remove_cvref_t = typename remove_cvref::type; - -// -// dependent_false -// -// @brief An always-false value that depends on one or more template parameters. -// See -// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1830r1.pdf -// https://github.com/cplusplus/papers/issues/572 -// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html -template -inline constexpr bool dependent_false = false; - -// -// tuple_size, tuple_element -// -// @brief CuTe-local tuple-traits to prevent conflicts with other libraries. -// For cute:: types, we specialize std::tuple-traits, which is explicitly allowed. -// cute::tuple, cute::array, cute::array_subbyte, etc -// But CuTe wants to treat some external types as tuples as well. For those, -// we specialize cute::tuple-traits to avoid polluting external traits. -// dim3, uint3, etc - -template -struct tuple_size; - -template -struct tuple_size::type>> : CUTE_STL_NAMESPACE::integral_constant::value> {}; - -template -constexpr size_t tuple_size_v = tuple_size::value; - -template -struct tuple_element; - -template -struct tuple_element::type>> : CUTE_STL_NAMESPACE::tuple_element {}; - -template -using tuple_element_t = typename tuple_element::type; - -// -// is_valid -// - -namespace detail { - -template ()(declval()...))> -CUTE_HOST_DEVICE constexpr auto -is_valid_impl(int) { return CUTE_STL_NAMESPACE::true_type{}; } - -template -CUTE_HOST_DEVICE constexpr auto -is_valid_impl(...) { return CUTE_STL_NAMESPACE::false_type{}; } - -template -struct is_valid_fn { - template - CUTE_HOST_DEVICE constexpr auto - operator()(Args&&...) const { return is_valid_impl(int{}); } -}; - -} // end namespace detail - -template -CUTE_HOST_DEVICE constexpr auto -is_valid(F&&) { - return detail::is_valid_fn{}; -} - -template -CUTE_HOST_DEVICE constexpr auto -is_valid(F&&, Args&&...) { - return detail::is_valid_impl(int{}); -} - -template class True, template class False> -struct conditional_template { - template - using type = True; -}; - -template class True, template class False> -struct conditional_template { - template - using type = False; -}; - -// -// is_any_of -// - -// Member `value` is true if and only if T is same as (is_same_v) at least one of the types in Us -template -struct is_any_of { - constexpr static bool value = (... || CUTE_STL_NAMESPACE::is_same_v); -}; - -// Is true if and only if T is same as (is_same_v) at least one of the types in Us -template -inline constexpr bool is_any_of_v = is_any_of::value; - -} // end namespace cute diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/aligned_buffer.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/aligned_buffer.h deleted file mode 100644 index 8468f54b347ba004a317b9957dcbf966228e7deb..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/aligned_buffer.h +++ /dev/null @@ -1,129 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief AlignedBuffer is a container for trivially copyable elements suitable for use in - unions and shared memory. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Modifies semantics of cutlass::Array<> to provide guaranteed alignment. -template < - typename T, - int N, - int Align = 16 -> -struct AlignedBuffer { - - /// Internal storage type - using Storage = uint8_t; - - /// Number of logical elements held in buffer - static int const kCount = N; - - /// Alignment requirement in bytes - static int const kAlign = Align; - - /// Number of storage elements - static int const kBytes = - (sizeof_bits::value * N + 7) / 8; - -private: - - /// Internal storage - alignas(Align) Storage storage[kBytes]; - -public: - - // - // C++ standard members - // - - typedef T value_type; - typedef size_t size_type; - typedef ptrdiff_t difference_type; - typedef value_type *pointer; - typedef value_type const * const_pointer; - - using Array = Array; - using reference = typename Array::reference; - using const_reference = typename Array::const_reference; - -public: - - CUTLASS_HOST_DEVICE - pointer data() { - return reinterpret_cast(storage); - } - - CUTLASS_HOST_DEVICE - const_pointer data() const { - return reinterpret_cast(storage); - } - - CUTLASS_HOST_DEVICE - Storage * raw_data() { - return storage; - } - - CUTLASS_HOST_DEVICE - Storage const * raw_data() const { - return storage; - } - - - CUTLASS_HOST_DEVICE - constexpr bool empty() const { - return !kCount; - } - - CUTLASS_HOST_DEVICE - constexpr size_type size() const { - return kCount; - } - - CUTLASS_HOST_DEVICE - constexpr size_type max_size() const { - return kCount; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/arch.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/arch.h deleted file mode 100644 index 5c3bd8a7755aea9e74fce2382b4d60250aec9ab1..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/arch.h +++ /dev/null @@ -1,129 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines tags for architecture-specific configurations. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -constexpr int sm100_smem_capacity_bytes = 232448; -constexpr int sm120_smem_capacity_bytes = 101376; - -#if defined(__NVCC__) || defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) - -/// Computes laneId within a warp -CUTLASS_DEVICE -int LaneId() { - int ret; - asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : ); - return ret; -} - -/// Computes SM number the thread is running on -CUTLASS_DEVICE -int SmId() { - int ret; - asm ("mov.u32 %0, %%smid;" : "=r"(ret) : ); - return ret; -} - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// -struct Sm50 { - static int const kMinComputeCapability = 50; -}; -struct Sm60 { - static int const kMinComputeCapability = 60; -}; -struct Sm61 { - static int const kMinComputeCapability = 61; -}; -struct Sm70 { - static int const kMinComputeCapability = 70; -}; -struct Sm72 { - static int const kMinComputeCapability = 72; -}; -struct Sm75 { - static int const kMinComputeCapability = 75; -}; -struct Sm80 { - static int const kMinComputeCapability = 80; -}; -struct Sm86 { - static int const kMinComputeCapability = 86; -}; -struct Sm89 { - static int const kMinComputeCapability = 89; -}; -struct Sm90 { - static int const kMinComputeCapability = 90; -}; - - -struct Sm100 { - static int const kMinComputeCapability = 100; -}; - -struct Sm101 { - static int const kMinComputeCapability = 101; -}; - -struct Sm120 { - static int const kMinComputeCapability = 120; -}; - -struct Sm103 { - static int const kMinComputeCapability = 103; -}; - -/// Triggers a breakpoint on the device -CUTLASS_DEVICE -void device_breakpoint() { -#if defined(__CUDA_ARCH__) - asm volatile (" brkpt;\n"); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/barrier.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/barrier.h deleted file mode 100644 index 2b0d4bb64078553b496eec7992d11f81ef4fc33e..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/barrier.h +++ /dev/null @@ -1,912 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Barrier Operations on SM90+ -*/ - -#pragma once - -#include -#include -#include -#include - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12) -#define CUDA_BARRIER_ENABLED 1 -#else -#define CUDA_BARRIER_ENABLED 0 -#endif - - -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED)) -#define CUTLASS_ARCH_TCGEN_ENABLED 1 -#endif - -#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED)) -#define CUTLASS_ARCH_TCGEN_ENABLED 1 -#endif - -namespace cutlass { -/// @brief -namespace arch { - -//////////////////////////////////////////////////////////////////////////////////////////////////// -CUTLASS_DEVICE void fence_view_async_shared(); - -namespace detail { // namespace detail begin - -// Single threaded versions that need to be called in an elect_one region -template -CUTLASS_DEVICE -void initialize_barrier_array(T ptr, int arv_cnt) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Stages; i++) { - ptr[i].init(arv_cnt); - } -} - -template -CUTLASS_DEVICE -void initialize_barrier_array(uint64_t *ptr, int arv_cnt) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Stages; i++) { - T::init(&ptr[i], arv_cnt); - } -} - -template -CUTLASS_DEVICE -void initialize_barrier_array_pair(FullBarrier full_barriers, EmptyBarrier empty_barriers, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Stages; i++) { - full_barriers[i].init(full_barrier_arv_cnt); - empty_barriers[i].init(empty_barrier_arv_cnt); - } -} - -template -CUTLASS_DEVICE -void initialize_barrier_array_pair(uint64_t *full_barriers_ptr, uint64_t *empty_barriers_ptr, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Stages; i++) { - FullBarrier::init(&full_barriers_ptr[i], full_barrier_arv_cnt); - EmptyBarrier::init(&empty_barriers_ptr[i], empty_barrier_arv_cnt); - } -} - -// Aligned versions that need to be call warp wide -template -CUTLASS_DEVICE -void initialize_barrier_array_aligned(T ptr, int arv_cnt) { - if(cute::elect_one_sync()) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Stages; i++) { - ptr[i].init(arv_cnt); - } - } -} - -template -CUTLASS_DEVICE -void initialize_barrier_array_aligned(uint64_t *ptr, int arv_cnt) { - if(cute::elect_one_sync()) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Stages; i++) { - T::init(&ptr[i], arv_cnt); - } - } -} - -template -CUTLASS_DEVICE -void initialize_barrier_array_pair_aligned(FullBarrier full_barriers, EmptyBarrier empty_barriers, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { - if(cute::elect_one_sync()) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Stages; i++) { - full_barriers[i].init(full_barrier_arv_cnt); - empty_barriers[i].init(empty_barrier_arv_cnt); - } - } -} - -template -CUTLASS_DEVICE -void initialize_barrier_array_pair_aligned(uint64_t *full_barriers_ptr, uint64_t *empty_barriers_ptr, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { - if(cute::elect_one_sync()) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Stages; i++) { - FullBarrier::init(&full_barriers_ptr[i], full_barrier_arv_cnt); - EmptyBarrier::init(&empty_barriers_ptr[i], empty_barrier_arv_cnt); - } - } -} - -} // namespace detail end - - - - -// There are 16 Named Barriers provided by Hardware starting in Hopper -// Their IDs are in the range 0-15 -// Number of threads syncing using the barrier must be a multiple of warp-size -// ID 0 should not be used for safety, as other driver APIs (i.e. __syncthreads) -// may use it and conflict with other uses. - - -// Enumerates the reserved named barriers to avoid potential conflicts -// This enum class specifies the NamedBarriers reserved by CUTLASS. -enum class ReservedNamedBarriers { - EpilogueBarrier = 1, - TransposeBarrier = 2, - TransformBarrier = 3, - StreamkBarrier0 = 4, - StreamkBarrier1 = 5 - , TmemAllocBarrier = 6 - , Sm120MainloopBarrier = 7 - , FirstUserBarrier = Sm120MainloopBarrier + 1 -}; - - -class NamedBarrier { - - // Data Members: - - // Range = [1 , NUM_THREADS_PER_CTA] - // Range % warp-size (i.e 32) == 0 - uint32_t const num_threads_; - - // Range : [0, 15] - // Note that should be set to the final barrier ID, including ReserveNamedBarrierCount should be considered - uint32_t const id_; - - public: - - // Constructor for CUTLASS developers: - // effective barrier ID starts from 0 - CUTLASS_DEVICE - NamedBarrier(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) - : num_threads_(num_threads), id_(static_cast(reserved_named_barriers)) {} - - // Constructor for CUTLASS users: - // effective barrier ID starts from ReservedNamedBarrierCount - CUTLASS_DEVICE - NamedBarrier(uint32_t num_threads, uint32_t id = 0) - : num_threads_(num_threads), id_(id + ReservedNamedBarrierCount) { - CUTLASS_ASSERT(id + ReservedNamedBarrierCount <= HardwareMaxNumNamedBarriers && "Effective barrier_id should not exceed 16."); - } - - CUTLASS_DEVICE - void arrive_and_wait() const { - // Note: The value of id_ is already the final barrier id (set correctly in the constructor). - NamedBarrier::arrive_and_wait_internal(num_threads_, id_); - } - - CUTLASS_DEVICE - void arrive_and_wait_unaligned() const { - // Note: The value of id_ is already the final barrier id (set correctly in the constructor). - NamedBarrier::arrive_and_wait_internal_unaligned(num_threads_, id_); - } - - CUTLASS_DEVICE - void arrive() const { - // Note: The value of id_ is already the final barrier id (set correctly in the constructor). - NamedBarrier::arrive_internal(num_threads_, id_); - } - - CUTLASS_DEVICE - void arrive_unaligned() const { - // Note: The value of id_ is already the final barrier id (set correctly in the constructor). - NamedBarrier::arrive_internal_unaligned(num_threads_, id_); - } - - CUTLASS_DEVICE - void sync() const { - NamedBarrier::arrive_and_wait(); - } - - // Static variants - - // Calling interface for CUTLASS users: - // effective barrier ID starts from ReservedNamedBarrierCount - CUTLASS_DEVICE - static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) { - arrive_and_wait_internal(num_threads, barrier_id + ReservedNamedBarrierCount); - } - - // Calling interface for CUTLASS developers: - // effective barrier ID starts from 0 - CUTLASS_DEVICE - static void arrive_and_wait(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { - arrive_and_wait_internal(num_threads, static_cast(reserved_named_barriers)); - } - - // Calling interface for CUTLASS users: - // effective barrier ID starts from ReservedNamedBarrierCount - CUTLASS_DEVICE - static void arrive(uint32_t num_threads, uint32_t barrier_id) { - arrive_internal(num_threads, barrier_id + ReservedNamedBarrierCount); - } - - // Calling interface for CUTLASS developers: - // effective barrier ID starts from 0 - CUTLASS_DEVICE - static void arrive(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { - arrive_internal(num_threads, static_cast(reserved_named_barriers)); - } - - // Calling interface for CUTLASS users: - // effective barrier ID starts from ReservedNamedBarrierCount - CUTLASS_DEVICE - static void sync(uint32_t num_threads, uint32_t barrier_id) { - sync_internal(num_threads, barrier_id + ReservedNamedBarrierCount); - } - - // Calling interface for CUTLASS developers: - // effective barrier ID starts from 0 - CUTLASS_DEVICE - static void sync(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { - sync_internal(num_threads, static_cast(reserved_named_barriers)); - } - - - private: - CUTLASS_DEVICE - static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) { -#if CUDA_BARRIER_ENABLED - asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); - cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - CUTLASS_DEVICE - static void arrive_and_wait_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { -#if CUDA_BARRIER_ENABLED - asm volatile("barrier.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); - cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - CUTLASS_DEVICE - static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) { -#if CUDA_BARRIER_ENABLED - cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); - asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - CUTLASS_DEVICE - static void arrive_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { -#if CUDA_BARRIER_ENABLED - cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); - asm volatile("barrier.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - CUTLASS_DEVICE - static void sync_internal(uint32_t num_threads, uint32_t barrier_id) { - NamedBarrier::arrive_and_wait_internal(num_threads, barrier_id); - } - - public: - // Currently we reserve 8 NamedBarriers for CUTLASS' own use cases, - // while leaving the renaming for general users. - static const uint32_t ReservedNamedBarrierCount = static_cast(ReservedNamedBarriers::FirstUserBarrier); - static const uint32_t HardwareMaxNumNamedBarriers = 16; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Hopper introduces a new cluster-wide barrier which handle with Cluster-wide arrive-wait behaviour. -// This is an extension to the Ampere arrive-wait barriers -// Note : Ampere arrive-wait Barriers have a larger max-arrive count (2^30) than Hopper arrive-wait Barriers (2^20). -struct ClusterBarrier { - - using ValueType = uint64_t; - -protected: - // Can never be initialized - can only be aliased to smem - ValueType barrier_; - -public: - - CUTLASS_DEVICE - ClusterBarrier() = delete; - - CUTLASS_DEVICE - void init(uint32_t arrive_count) const { - ClusterBarrier::init(&this->barrier_, arrive_count); - } - - CUTLASS_DEVICE - bool test_wait(uint32_t phase, uint32_t pred=true) const { - return ClusterBarrier::test_wait(&this->barrier_, phase, pred); - } - - CUTLASS_DEVICE - bool try_wait(uint32_t phase) const { - return ClusterBarrier::try_wait(&this->barrier_, phase); - } - - CUTLASS_DEVICE - void wait(uint32_t phase) const { - ClusterBarrier::wait(&this->barrier_, phase); - } - - // Barrier arrive on local smem - CUTLASS_DEVICE - void arrive() const { - ClusterBarrier::arrive(&this->barrier_); - } - - // Remote SMEM arrive with a perdicate (usually done to pick the thread doing the arrive) - CUTLASS_DEVICE - void arrive(uint32_t cta_id, uint32_t pred = true ) const { - ClusterBarrier::arrive(&this->barrier_, cta_id, pred); - } - - // - // Static Versions - // - CUTLASS_HOST_DEVICE - static void init(ValueType const* smem_ptr, uint32_t arrive_count) { - CUTLASS_ASSERT(arrive_count != 0 && "Arrive count must be non-zero"); -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - "mbarrier.init.shared::cta.b64 [%1], %0; \n" - "}" - : - : "r"(arrive_count), "r"(smem_addr)); - cutlass::arch::synclog_emit_cluster_barrier_init(__LINE__, smem_addr, arrive_count); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - // Static version of wait - in case we don't want to burn a register - CUTLASS_HOST_DEVICE - static void wait(ValueType const* smem_ptr, uint32_t phase) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - cutlass::arch::synclog_emit_cluster_barrier_wait(__LINE__, smem_addr, phase); - // Arbitrarily large timer value after which try-wait expires and re-tries. - uint32_t ticks = 0x989680; - asm volatile( - "{\n\t" - ".reg .pred P1; \n\t" - "LAB_WAIT: \n\t" - "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" - "@P1 bra DONE; \n\t" - "bra LAB_WAIT; \n\t" - "DONE: \n\t" - "}" - : - : "r"(smem_addr), "r"(phase), "r"(ticks)); - -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - CUTLASS_HOST_DEVICE - static bool test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - cutlass::arch::synclog_emit_cluster_barrier_test_wait(__LINE__, smem_addr, phase, pred); - uint32_t waitComplete; - - asm volatile( - "{\n\t" - ".reg .pred P1; \n\t" - ".reg .pred P2; \n\t" - "setp.eq.u32 P2, %3, 1;\n\t" - "@P2 mbarrier.test_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" - "selp.b32 %0, 1, 0, P1; \n\t" - "}" - : "=r"(waitComplete) - : "r"(smem_addr), "r"(phase), "r"(pred)); - - return static_cast(waitComplete); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - return 0; - } - - CUTLASS_HOST_DEVICE - static bool try_wait(ValueType const* smem_ptr, uint32_t phase) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - cutlass::arch::synclog_emit_cluster_barrier_try_wait(__LINE__, smem_addr, phase); - uint32_t waitComplete; - - asm volatile( - "{\n\t" - ".reg .pred P1; \n\t" - "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" - "selp.b32 %0, 1, 0, P1; \n\t" - "}" - : "=r"(waitComplete) - : "r"(smem_addr), "r"(phase)); - - return static_cast(waitComplete); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - return 0; - } - - // Static Predicated version of the above - in case we know the address. - CUTLASS_HOST_DEVICE - static void arrive(ValueType const* smem_ptr, uint32_t cta_id, uint32_t pred) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - if (pred) { - asm volatile( - "{\n\t" - ".reg .b32 remAddr32;\n\t" - "mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" - "mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" - "}" - : - : "r"(smem_addr), "r"(cta_id)); - } - - cutlass::arch::synclog_emit_cluster_barrier_arrive_cluster(__LINE__, smem_addr, cta_id, pred); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - // Barrier arrive on local smem - CUTLASS_HOST_DEVICE - static void arrive(ValueType const* smem_ptr) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - "mbarrier.arrive.shared::cta.b64 _, [%0];\n\t" - "}" - : - : "r"(smem_addr)); - cutlass::arch::synclog_emit_cluster_barrier_arrive(__LINE__, smem_addr); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - CUTLASS_HOST_DEVICE - static void invalidate(ValueType const* smem_ptr) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - "mbarrier.inval.shared::cta.b64 [%0]; \n\t" - "}" - : - : "r"(smem_addr)); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SM90 also introduces a new type of cluster-barrier which supports sync. -// not just based on Arrive Count, but also transaction count (in bytes) -struct ClusterTransactionBarrier : public ClusterBarrier { - - CUTLASS_DEVICE - ClusterTransactionBarrier() = delete; - - // Performs an arrive operation + expected transaction bytes increment - CUTLASS_DEVICE - void arrive_and_expect_tx(uint32_t transaction_bytes) const { - ClusterTransactionBarrier::arrive_and_expect_tx(&this->barrier_, transaction_bytes); - } - - // Performs an arrive operation + expected transaction bytes increment - CUTLASS_DEVICE - void arrive_and_expect_tx(uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred = 1u) const { - ClusterTransactionBarrier::arrive_and_expect_tx(&this->barrier_, transaction_bytes , cta_id, pred); - } - - // Performs an expected transaction bytes increment without doing an arrive operation - CUTLASS_DEVICE - void expect_transaction(uint32_t transaction_bytes) const { - ClusterTransactionBarrier::expect_transaction(&this->barrier_, transaction_bytes); - } - - // Performs an expected transaction bytes decrement without doing an arrive operation - CUTLASS_DEVICE - void complete_transaction(uint32_t transaction_bytes, uint32_t pred = 1) const { - uint32_t cta_rank = cute::block_rank_in_cluster(); - ClusterTransactionBarrier::complete_transaction(&this->barrier_, cta_rank, transaction_bytes, pred); - } - - // Performs an expected transaction bytes decrement without doing an arrive operation - CUTLASS_DEVICE - void complete_transaction(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { - ClusterTransactionBarrier::complete_transaction(&this->barrier_, dst_cta_id, transaction_bytes, pred); - } - - // - // Static Versions - // - - // Performs an arrive operation + expected transaction bytes increment - CUTLASS_HOST_DEVICE - static void arrive_and_expect_tx(ValueType const* smem_ptr, uint32_t transaction_bytes) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" - "}" - : - : "r"(transaction_bytes), "r"(smem_addr)); - cutlass::arch::synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx(__LINE__, smem_addr, transaction_bytes); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - // Performs an arrive operation + expected transaction bytes increment for a remote cta_id in a Cluster - CUTLASS_HOST_DEVICE - static void arrive_and_expect_tx( - ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - ".reg .pred p;\n\t" - ".reg .b32 remAddr32;\n\t" - "setp.eq.u32 p, %2, 1;\n\t" - "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" - "@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, [remAddr32], %3;\n\t" - "}" - : - : "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes)); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - // Performs an expected transaction bytes increment without doing an arrive operation - CUTLASS_HOST_DEVICE - static void expect_transaction(ValueType const* smem_ptr, uint32_t transaction_bytes) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - "mbarrier.expect_tx.shared::cta.b64 [%1], %0; \n\t" - "}" - : - : "r"(transaction_bytes), "r"(smem_addr)); - cutlass::arch::synclog_emit_cluster_transaction_barrier_expect_transaction(__LINE__, smem_addr, transaction_bytes); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - // Performs an expected transaction bytes decrement without doing an arrive operation - CUTLASS_HOST_DEVICE - static void complete_transaction( - ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - smem_addr = cute::set_block_rank(smem_addr, dst_cta_id); - asm volatile( - "{\n\t" - ".reg .pred p;\n\t" - "setp.eq.u32 p, %2, 1;\n\t" - "@p mbarrier.complete_tx.shared::cluster.relaxed.cluster.b64 [%1], %0;" - "}" - : - : "r"(transaction_bytes), "r"(smem_addr), "r"(pred)); - cutlass::arch::synclog_emit_cluster_transaction_barrier_complete_transaction(__LINE__, smem_addr, dst_cta_id, transaction_bytes, pred); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif - } - - // - // DEPRECATED APIs - // - [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE - void arrive_and_reset_bytes(uint32_t transaction_bytes) const { - arrive_and_expect_tx(transaction_bytes); - } - [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE - void arrive_and_reset_bytes(uint32_t transaction_bytes, uint32_t cta_id) const { - arrive_and_expect_tx(transaction_bytes, cta_id); - } - [[deprecated("Use expect_transaction instead")]] CUTLASS_DEVICE - void reset_bytes(uint32_t transaction_bytes) const { - expect_transaction(transaction_bytes); - } - [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE - void commit(uint32_t transaction_bytes, uint32_t pred = 1) const { - complete_transaction(transaction_bytes, pred); - } - [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE - void commit(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { - complete_transaction(dst_cta_id, transaction_bytes, pred); - } - [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE - static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { - arrive_and_expect_tx(smem_ptr, transaction_bytes); - } - [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE - static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { - arrive_and_expect_tx(smem_ptr, transaction_bytes, cta_id, pred); - } - [[deprecated("Use expect_transaction instead")]] CUTLASS_DEVICE - static void reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { - expect_transaction(smem_ptr, transaction_bytes); - } - [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE - static void commit(ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { - complete_transaction(smem_ptr, dst_cta_id, transaction_bytes, pred); - } -}; - -// Helps with visibility of barrier init operations across warps / cta / cluster -// Available as a separate function so as to batch inits across barriers and fence once -// Note : It must be composed with an appropriate sync instruction with the right scope -// to ensure visibility eg. __syncthreads() or a cluster_arrive() + cluster_wait() -CUTLASS_DEVICE -void fence_barrier_init() { -#if CUDA_BARRIER_ENABLED - cutlass::arch::synclog_emit_fence_barrier_init(__LINE__); - asm volatile( - "{\n\t" - "fence.mbarrier_init.release.cluster; \n" - "}" - ::); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif -} - -// Issue a shared memory fence for async operations -CUTLASS_DEVICE -void fence_view_async_shared() { -#if CUDA_BARRIER_ENABLED - cutlass::arch::synclog_emit_fence_view_async_shared(__LINE__); - asm volatile ( - "{\n\t" - "fence.proxy.async.shared::cta; \n" - "}" - ::); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif -} - -// Arrive on completion of in-flight cp.async operations issued by the calling thread -CUTLASS_HOST_DEVICE -void cpasync_barrier_arrive(uint64_t const* smem_ptr) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - "cp.async.mbarrier.arrive.shared::cta.b64 [%0];\n\t" - "}" - : - : "r"(smem_addr)); - cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_addr); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif -} - -// Arrive on completion of in-flight cp.async operations issued by the calling thread (noinc) -CUTLASS_HOST_DEVICE -void cpasync_barrier_arrive_noinc(uint64_t const* smem_ptr) { -#if CUDA_BARRIER_ENABLED - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - "cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t" - "}" - : - : "r"(smem_addr)); - cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_addr); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -CUTLASS_HOST_DEVICE -void umma_arrive(uint64_t const* smem_ptr) { -#if defined(CUTLASS_ARCH_TCGEN_ENABLED) - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); - if (cute::elect_one_sync()) { - asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" - : - :"r"(bar_intptr)); - } -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif -} - -//UMMA arrive for MMA_2x1SM -CUTLASS_HOST_DEVICE -void umma_arrive_2x1SM(uint64_t const* smem_ptr) { -#if defined(CUTLASS_ARCH_TCGEN_ENABLED) - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); - if (cute::elect_one_sync()) { - asm volatile("tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];" - : - :"r"(bar_intptr)); - } -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif -} - -// UMMA arrive for MMA_1sm + TMA_LOAD_MULTICAST combination -CUTLASS_HOST_DEVICE -void umma_arrive_multicast(uint64_t const* smem_ptr, uint16_t cta_mask) { -#if defined(CUTLASS_ARCH_TCGEN_ENABLED) - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); - if(cute::elect_one_sync()) { - asm volatile( - "{\n\t" - "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" - "}" - : - :"r"(bar_intptr), "h"(cta_mask)); - } -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif -} - -// UMMA arrive for MMA_2x1SM + TMA_LOAD_MULTICAST combination -CUTLASS_HOST_DEVICE -void umma_arrive_multicast_2x1SM(uint64_t const* smem_ptr, uint16_t cta_mask) { -#if defined(CUTLASS_ARCH_TCGEN_ENABLED) - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); - if (cute::elect_one_sync()) { - asm volatile( - "{\n\t" - "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" - "}" - : - :"r"(bar_intptr), "h"(cta_mask)); - } -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif -} - -// Temporary solution for sparse kernel. -// Will remove this when we done tightly elect_one wrap. -CUTLASS_HOST_DEVICE -void umma_arrive_multicast_no_elect(uint64_t const* smem_ptr, uint16_t cta_mask) { -#if defined(CUTLASS_ARCH_TCGEN_ENABLED) - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - ".reg .b16 lo, hi;\n\t" - "mov.b32 {lo, hi}, %1;\n\t" - "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], lo; \n\t" - "}" - : - :"r"(bar_intptr), "r"(uint32_t(cta_mask))); -#elif defined(__CUDA_ARCH__) - CUTLASS_NOT_IMPLEMENTED(); -#endif -} - -// Temporary solution for sparse kernel. -// UMMA arrive for MMA_2x1SM + TMA_LOAD_MULTICAST combination -CUTLASS_HOST_DEVICE -void umma_arrive_multicast_2x1SM_no_elect(uint64_t const* smem_ptr, uint16_t cta_mask) { -#if defined(CUTLASS_ARCH_TCGEN_ENABLED) - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - ".reg .b16 lo, hi;\n\t" - "mov.b32 {lo, hi}, %1;\n\t" - "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], lo; \n\t" - "}" - : - :"r"(bar_intptr), "r"(uint32_t(cta_mask))); -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif -} - -// Always arrive on even SM of collaborating 2 SMs. -CUTLASS_HOST_DEVICE -void umma_arrive_2x1SM_sm0(uint64_t const* smem_ptr) { -#if defined(CUTLASS_ARCH_TCGEN_ENABLED) - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr) & cute::Sm100MmaPeerBitMask; - asm volatile ( - "{\n\t" - "mbarrier.arrive.shared::cluster.b64 _, [%0];\n\t" - "}" - : - : "r"(bar_intptr)); - -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif -} - -CUTE_DEVICE static void fence_view_async_tmem_load() { -#if defined(CUTLASS_ARCH_TCGEN_ENABLED) - asm volatile ( - "{\n\t" - "tcgen05.wait::ld.sync.aligned; \n" - "}" - ::); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif -} - -CUTE_DEVICE static void fence_view_async_tmem_store() { -#if defined(CUTLASS_ARCH_TCGEN_ENABLED) - asm volatile ( - "{\n\t" - "tcgen05.wait::st.sync.aligned; \n" - "}" - ::); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); -#endif -} - - -//////////////////////////////////////////////////////////////////////////////////////////////////// -} // end namespace arch -} // end namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/cache_operation.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/cache_operation.h deleted file mode 100644 index 5128ee02cec37a3718226a48d6c72edae58b5c09..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/cache_operation.h +++ /dev/null @@ -1,66 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Directives related to cache operations -*/ -#pragma once - -#include "cutlass/cutlass.h" - -namespace cutlass { -namespace arch { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Controls PTX cache operations -struct CacheOperation { - enum Kind { - /// Cache at all levels - accessed again - Always, - /// Cache at global level - Global, - /// Streaming - likely to be accessed once - Streaming, - /// Indicates the line will not be used again - LastUse, - /// Don't cache, and fetch again - Volatile, - /// Write back at all coherent levels - WriteBack, - /// Write through to system memory - WriteThrough - }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/config.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/config.h deleted file mode 100644 index 873e64375dc8e404f715fa3fe2c11c0eced85a1e..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/config.h +++ /dev/null @@ -1,230 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Definitions for architecture macros -*/ - -#pragma once - -#include "cutlass/platform/platform.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// SM90 -#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 0)) - #define CUTLASS_ARCH_MMA_SM90_SUPPORTED 1 - #if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900) - #define CUTLASS_ARCH_MMA_SM90_ENABLED 1 - - #if (!defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) - #define CUTLASS_ARCH_MMA_SM90A_ENABLED 1 - #endif - #endif -#endif - -#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 2)) - #define CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Modifiable TMA -// tensormap.replace is arch conditional -#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 3)) - #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED 1 - #if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) && \ - (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM100_ALL) || \ - defined(__CUDA_ARCH_FEAT_SM101_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL))) - #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED 1 - #endif -#endif - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// SM90 F64 -#if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) - #define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED 1 - #if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900) - #define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED 1 - #endif -#endif - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// SM100, SM100a -#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) - #define CUTLASS_ARCH_MMA_SM100_SUPPORTED 1 - #if (!defined(CUTLASS_ARCH_MMA_SM100_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1000) - #define CUTLASS_ARCH_MMA_SM100_ENABLED 1 - - #if (!defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM100_ALL)) - #define CUTLASS_ARCH_MMA_SM100A_ENABLED 1 - #endif - - // SM100f - #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) - #define CUTLASS_ARCH_MMA_SM100F_SUPPORTED 1 - #endif - - #if (!defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) && CUDA_ARCH_FAMILY(1000)) - #define CUTLASS_ARCH_MMA_SM100F_ENABLED CUTLASS_ARCH_MMA_SM100F_SUPPORTED - #endif - #endif -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// SM101 and SM101a -#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) - #define CUTLASS_ARCH_MMA_SM101_SUPPORTED 1 - #if (!defined(CUTLASS_ARCH_MMA_SM101_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1010) - #define CUTLASS_ARCH_MMA_SM101_ENABLED 1 - - #if (!defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM101_ALL)) - #define CUTLASS_ARCH_MMA_SM101A_ENABLED 1 - #endif - - // SM101f - #if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) - #define CUTLASS_ARCH_MMA_SM101F_SUPPORTED 1 - #endif - - #if (!defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) && CUDA_ARCH_FAMILY(1010)) - #define CUTLASS_ARCH_MMA_SM101F_ENABLED CUTLASS_ARCH_MMA_SM101F_SUPPORTED - #endif - #endif -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// SM110 and SM110a only on 13.0 and above -#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 13 || (__CUDACC_VER_MAJOR__ == 13 && __CUDACC_VER_MINOR__ >= 0)) - #define CUTLASS_ARCH_MMA_SM110_SUPPORTED 1 - #if (!defined(CUTLASS_ARCH_MMA_SM110_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1100) - #define CUTLASS_ARCH_MMA_SM110_ENABLED 1 - - #if (!defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM110_ALL)) - #define CUTLASS_ARCH_MMA_SM110A_ENABLED 1 - #endif - - // SM110f - #if (__CUDACC_VER_MAJOR__ > 13 || (__CUDACC_VER_MAJOR__ == 13 && __CUDACC_VER_MINOR__ >= 0)) - #define CUTLASS_ARCH_MMA_SM110F_SUPPORTED 1 - #endif - - #if (!defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) && CUDA_ARCH_FAMILY(1100)) - #define CUTLASS_ARCH_MMA_SM110F_ENABLED CUTLASS_ARCH_MMA_SM110F_SUPPORTED - #endif - #endif -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// SM120 and SM120a -#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) - #define CUTLASS_ARCH_MMA_SM120_SUPPORTED 1 - #if (!defined(CUTLASS_ARCH_MMA_SM120_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200) - #define CUTLASS_ARCH_MMA_SM120_ENABLED 1 - - #if (!defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM120_ALL)) - #define CUTLASS_ARCH_MMA_SM120A_ENABLED 1 - #endif - - // SM120f - #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) - #define CUTLASS_ARCH_MMA_SM120F_SUPPORTED 1 - #endif - - #if (!defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) && CUDA_ARCH_FAMILY(1200)) - #define CUTLASS_ARCH_MMA_SM120F_ENABLED CUTLASS_ARCH_MMA_SM120F_SUPPORTED - #endif - #endif -#endif - -// SM103 and SM103a -#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) - #define CUTLASS_ARCH_MMA_SM103_SUPPORTED 1 - #if (!defined(CUTLASS_ARCH_MMA_SM103_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1030) - #define CUTLASS_ARCH_MMA_SM103_ENABLED 1 - - #if (!defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM103_ALL)) - #define CUTLASS_ARCH_MMA_SM103A_ENABLED 1 - #endif - - // SM103f - #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) - #define CUTLASS_ARCH_MMA_SM103F_SUPPORTED 1 - #endif - - #if (!defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) && CUDA_ARCH_FAMILY(1030)) - #define CUTLASS_ARCH_MMA_SM103F_ENABLED CUTLASS_ARCH_MMA_SM103F_SUPPORTED - #endif - #endif -#endif - -// SM121 and SM121a -#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) - #define CUTLASS_ARCH_MMA_SM121_SUPPORTED 1 - #if (!defined(CUTLASS_ARCH_MMA_SM121_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1210) - #define CUTLASS_ARCH_MMA_SM121_ENABLED 1 - - #if (!defined(CUTLASS_ARCH_MMA_SM121A_ENABLED) &&\ - (defined(__CUDA_ARCH_FEAT_SM121_ALL) || CUDA_ARCH_CONDITIONAL(1210))) - #define CUTLASS_ARCH_MMA_SM121A_ENABLED 1 - #endif - - // SM121f - #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) - #define CUTLASS_ARCH_MMA_SM121F_SUPPORTED 1 - #endif - - #if (!defined(CUTLASS_ARCH_MMA_SM121F_ENABLED) && CUDA_ARCH_FAMILY(1210)) - #define CUTLASS_ARCH_MMA_SM121F_ENABLED CUTLASS_ARCH_MMA_SM121F_SUPPORTED - #endif - #endif -#endif - - -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM121A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) -# define CUTLASS_ARCH_CLC_ENABLED -#endif - - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/grid_dependency_control.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/grid_dependency_control.h deleted file mode 100644 index e5e99c855871787b6e1820895940cf8775caba25..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/grid_dependency_control.h +++ /dev/null @@ -1,107 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Grid dependent control (GDC) helpers for programmatic dependent launches (PDL). -*/ - -#pragma once - -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/barrier.h" -#include "cutlass/conv/dispatch_policy.hpp" -#include "cutlass/gemm/dispatch_policy.hpp" - -#ifndef CUTLASS_GDC_ENABLED - #if (CUDA_BARRIER_ENABLED && \ - defined(CUTLASS_ENABLE_GDC_FOR_SM90) && \ - __CUDACC_VER_MAJOR__ >= 12 && \ - defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) - #define CUTLASS_GDC_ENABLED - #endif - #if (defined(CUTLASS_ENABLE_GDC_FOR_SM100) && \ - __CUDACC_VER_MAJOR__ >= 12 && \ - defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) - #define CUTLASS_GDC_ENABLED - #endif -#endif - -#ifndef CUTLASS_GDC_ENABLED - #if(CUDA_BARRIER_ENABLED && \ - defined(CUTLASS_ENABLE_GDC_FOR_SM100) && \ - defined(__CUDA_ARCH__) && \ - ((__CUDA_ARCH__ == 1000 &&\ - (defined(__CUDA_ARCH_FEAT_SM100_ALL) || CUDA_ARCH_FAMILY(1000))) || \ - (__CUDA_ARCH__ == 1010 &&\ - (defined(__CUDA_ARCH_FEAT_SM101_ALL) || CUDA_ARCH_FAMILY(1010))) || \ - (__CUDA_ARCH__ == 1030 &&\ - (defined(__CUDA_ARCH_FEAT_SM103_ALL) || CUDA_ARCH_FAMILY(1030))) || \ - (__CUDA_ARCH__ == 1200 &&\ - (defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_FAMILY(1200))) || \ - (__CUDA_ARCH__ == 1210 &&\ - (defined(__CUDA_ARCH_FEAT_SM121_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210))))) - #define CUTLASS_GDC_ENABLED - #endif -#endif - -namespace cutlass { -namespace arch { - -// Issuing the launch_dependents instruction hints a dependent kernel to launch earlier -// launch_dependents doesn't impact the functionality but the performance: -// Launching a dependent kernel too early can compete with current kernels, -// while launching too late can lead to a long latency. -CUTLASS_DEVICE -void launch_dependent_grids() { -#if (defined(CUTLASS_GDC_ENABLED)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif -} - -// Issuing the griddepcontrol.wait instruction enforces no global memory access -// prior to this istruction. This ensures the correctness of global memory access -// when launching a dependent kernel earlier. -CUTLASS_DEVICE -void wait_on_dependent_grids() { -#if (defined(CUTLASS_GDC_ENABLED)) - asm volatile("griddepcontrol.wait;"); -#endif -} - -// Enable kernel-level query regarding whether the GDC feature is turned on -#if (defined(CUTLASS_GDC_ENABLED)) -static constexpr bool IsGdcGloballyEnabled = true; -#else -static constexpr bool IsGdcGloballyEnabled = false; -#endif - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/memory.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/memory.h deleted file mode 100644 index 0fb47b1744137a6f64097ad2879da74d160ccc9c..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/memory.h +++ /dev/null @@ -1,602 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Architecture-specific operators on memory -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/arch/cache_operation.h" -#include "cutlass/platform/platform.h" - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Fragment type to store loaded data - typename AccessType, - /// The bytes of loading - int LoadBytes, - /// Cache operation - CacheOperation::Kind cache_op = CacheOperation::Always - > -struct global_load; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Specializations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ - (__CUDACC_VER_MAJOR__ > 11)) && \ - defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) - #define CUTLASS_ENABLE_L2_PREFETCH 1 -#else - #define CUTLASS_ENABLE_L2_PREFETCH 0 -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// The redundant mov PTX instruction is used to enforce the compiler to -// keep the initializing code before ld.global -template -struct global_load { - CUTLASS_DEVICE - global_load(AccessType &D, void const *ptr, bool pred_guard) { - uint4 *data = reinterpret_cast(&D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %9, 0;\n" - " mov.b32 %0, %10;\n" - " mov.b32 %1, %11;\n" - " mov.b32 %2, %12;\n" - " mov.b32 %3, %13;\n" - " mov.b32 %4, %14;\n" - " mov.b32 %5, %15;\n" - " mov.b32 %6, %16;\n" - " mov.b32 %7, %17;\n" -#if CUTLASS_ENABLE_L2_PREFETCH - " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%8];\n" - " @p ld.global.L2::128B.v4.u32 {%4, %5, %6, %7}, [%18];\n" -#else - " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n" - " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n" -#endif - "}\n" - : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w), - "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w) - : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y), - "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y), - "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16)); - } -}; - -template -struct global_load { - CUTLASS_DEVICE - global_load(AccessType &D, void const *ptr, bool pred_guard) { - uint4 *data = reinterpret_cast(&D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %9, 0;\n" - " mov.b32 %0, %10;\n" - " mov.b32 %1, %11;\n" - " mov.b32 %2, %12;\n" - " mov.b32 %3, %13;\n" - " mov.b32 %4, %14;\n" - " mov.b32 %5, %15;\n" - " mov.b32 %6, %16;\n" - " mov.b32 %7, %17;\n" - " @p ld.global.lu.v4.u32 {%0, %1, %2, %3}, [%8];\n" - " @p ld.global.lu.v4.u32 {%4, %5, %6, %7}, [%18];\n" - "}\n" - : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w), - "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w) - : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y), - "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y), - "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16)); - } -}; - -template -struct global_load { - CUTLASS_DEVICE - global_load(AccessType &D, void const *ptr, bool pred_guard) { - uint4 &data = reinterpret_cast(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %5, 0;\n" - " mov.b32 %0, %6;\n" - " mov.b32 %1, %7;\n" - " mov.b32 %2, %8;\n" - " mov.b32 %3, %9;\n" -#if CUTLASS_ENABLE_L2_PREFETCH - " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%4];\n" -#else - " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" -#endif - "}\n" - : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) - : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); - } -}; - -template -struct global_load { - CUTLASS_DEVICE - global_load(AccessType &D, void const *ptr, bool pred_guard) { - uint4 &data = reinterpret_cast(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %5, 0;\n" - " mov.b32 %0, %6;\n" - " mov.b32 %1, %7;\n" - " mov.b32 %2, %8;\n" - " mov.b32 %3, %9;\n" - " @p ld.global.lu.v4.u32 {%0, %1, %2, %3}, [%4];\n" - "}\n" - : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) - : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); - } -}; - -template -struct global_load { - CUTLASS_DEVICE - global_load(AccessType &D, void const *ptr, bool pred_guard) { - uint2 &data = reinterpret_cast(D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %3, 0;\n" - " mov.b32 %0, %4;\n" - " mov.b32 %1, %5;\n" -#if CUTLASS_ENABLE_L2_PREFETCH - " @p ld.global.L2::128B.v2.u32 {%0, %1}, [%2];\n" -#else - " @p ld.global.v2.u32 {%0, %1}, [%2];\n" -#endif - "}\n" - : "=r"(data.x), "=r"(data.y) - : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y)); - } -}; - -template -struct global_load { - CUTLASS_DEVICE - global_load(AccessType &D, void const *ptr, bool pred_guard) { - uint2 &data = reinterpret_cast(D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %3, 0;\n" - " mov.b32 %0, %4;\n" - " mov.b32 %1, %5;\n" - " @p ld.global.lu.v2.u32 {%0, %1}, [%2];\n" - "}\n" - : "=r"(data.x), "=r"(data.y) - : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y)); - } -}; - -template -struct global_load { - CUTLASS_DEVICE - global_load(AccessType &D, void const *ptr, bool pred_guard) { - unsigned &data = reinterpret_cast(D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " mov.b32 %0, %3;\n" -#if CUTLASS_ENABLE_L2_PREFETCH - " @p ld.global.L2::128B.u32 %0, [%1];\n" -#else - " @p ld.global.u32 %0, [%1];\n" -#endif - "}\n" - : "=r"(data) - : "l"(ptr), "r"((int)pred_guard), "r"(data)); - } -}; - -template -struct global_load { - CUTLASS_DEVICE - global_load(AccessType &D, void const *ptr, bool pred_guard) { - unsigned &data = reinterpret_cast(D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " mov.b32 %0, %3;\n" - " @p ld.global.lu.u32 %0, [%1];\n" - "}\n" - : "=r"(data) - : "l"(ptr), "r"((int)pred_guard), "r"(data)); - } -}; - -template -struct global_load { - CUTLASS_DEVICE - global_load(AccessType &D, void const *ptr, bool pred_guard) { - uint16_t &data = reinterpret_cast(D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " mov.b16 %0, %3;\n" -#if CUTLASS_ENABLE_L2_PREFETCH - " @p ld.global.L2::128B.u16 %0, [%1];\n" -#else - " @p ld.global.u16 %0, [%1];\n" -#endif - "}\n" - : "=h"(data) - : "l"(ptr), "r"((int)pred_guard), "h"(data)); - } -}; - -template -struct global_load { - CUTLASS_DEVICE - global_load(AccessType &D, void const *ptr, bool pred_guard) { - uint16_t &data = reinterpret_cast(D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " mov.b16 %0, %3;\n" - " @p ld.global.lu.u16 %0, [%1];\n" - "}\n" - : "=h"(data) - : "l"(ptr), "r"((int)pred_guard), "h"(data)); - } -}; - -template -struct global_load { - CUTLASS_DEVICE - global_load(AccessType &D, void const *ptr, bool pred_guard) { - if (pred_guard) D = *(reinterpret_cast(ptr)); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Fragment type to store data - typename AccessType, - /// The bytes of storing - int StoreBytes - > -struct global_store; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Specializations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - - -template -struct global_store { - CUTLASS_DEVICE - global_store(AccessType const &D, void *ptr, bool pred_guard) { - uint4 const *data = reinterpret_cast(&D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %5, 0;\n" - " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" - " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n" - " @p st.global.v4.u32 [%11], {%12, %13, %14, %15};\n" - " @p st.global.v4.u32 [%16], {%17, %18, %19, %20};\n" - "}\n" - : - : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), - "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16), - "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w), - "l"(((uint8_t *)ptr) + 32), - "r"(data[2].x), "r"(data[2].y), "r"(data[2].z), "r"(data[2].w), - "l"(((uint8_t *)ptr) + 48), - "r"(data[3].x), "r"(data[3].y), "r"(data[3].z), "r"(data[3].w)); - } -}; - - -template -struct global_store { - CUTLASS_DEVICE - global_store(AccessType const &D, void *ptr, bool pred_guard) { - uint4 const *data = reinterpret_cast(&D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %5, 0;\n" - " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" - " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n" - "}\n" - : - : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), - "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16), - "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w)); - } -}; - -template -struct global_store { - CUTLASS_DEVICE - global_store(AccessType const &D, void *ptr, bool pred_guard) { - uint4 const &data = reinterpret_cast(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %5, 0;\n" - " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" - "}\n" - : - : "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w), "r"((int)pred_guard)); - } -}; - -template -struct global_store { - CUTLASS_DEVICE - global_store(AccessType const &D, void *ptr, bool pred_guard) { - uint2 const &data = reinterpret_cast(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %3, 0;\n" - " @p st.global.v2.u32 [%0], {%1, %2};\n" - "}\n" - : - : "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard)); - } -}; - -template -struct global_store { - CUTLASS_DEVICE - global_store(AccessType const &D, void *ptr, bool pred_guard) { - uint32_t const &data = reinterpret_cast(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " @p st.global.u32 [%0], %1;\n" - "}\n" - : - : "l"(ptr), "r"(data), "r"((int)pred_guard)); - } -}; - -template -struct global_store { - CUTLASS_DEVICE - global_store(AccessType const &D, void *ptr, bool pred_guard) { - uint16_t const &data = reinterpret_cast(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " @p st.global.u16 [%0], %1;\n" - "}\n" - : - : "l"(ptr), "h"(data), "r"((int)pred_guard)); - } -}; - -template -struct global_store { - CUTLASS_DEVICE - global_store(AccessType const &D, void *ptr, bool pred_guard) { - if (pred_guard) *(reinterpret_cast(ptr)) = D; - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// ld.shared -template -CUTLASS_DEVICE -void shared_load(void *dst, uint32_t ptr); - -/// ld.shared - 16b -template <> -CUTLASS_DEVICE -void shared_load<2>(void *dst, uint32_t ptr) { - asm volatile("ld.shared.u16 %0, [%1];\n" - : "=h"(*reinterpret_cast(dst)) - : "r"(ptr)); -} - -/// ld.shared - 32b -template <> -CUTLASS_DEVICE -void shared_load<4>(void *dst, uint32_t ptr) { - asm volatile("ld.shared.u32 %0, [%1];\n" - : "=r"(*reinterpret_cast(dst)) - : "r"(ptr)); -} - -/// ld.shared - 64b -template <> -CUTLASS_DEVICE -void shared_load<8>(void *dst, uint32_t ptr) { - uint2 *dst_u64 = reinterpret_cast(dst); - asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n" - : - "=r"(dst_u64->x), - "=r"(dst_u64->y) - : "r"(ptr)); -} - -/// ld.shared - 128b -template <> -CUTLASS_DEVICE -void shared_load<16>(void *dst, uint32_t ptr) { - uint4 *dst_u128 = reinterpret_cast(dst); - asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" - : - "=r"(dst_u128->x), - "=r"(dst_u128->y), - "=r"(dst_u128->z), - "=r"(dst_u128->w) - : "r"(ptr)); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// st.shared -template -CUTLASS_DEVICE -void shared_store(uint32_t ptr, void const *src); - -/// st.shared - 16b -template <> -CUTLASS_DEVICE -void shared_store<2>(uint32_t ptr, void const *src) { - asm volatile("st.shared.u16 [%0], %1;\n" - : : - "r"(ptr), - "h"(*reinterpret_cast(src)) - ); -} - -/// st.shared - 32b -template <> -CUTLASS_DEVICE -void shared_store<4>(uint32_t ptr, void const *src) { - asm volatile("st.shared.u32 [%0], %1;\n" - : : - "r"(ptr), - "r"(*reinterpret_cast(src)) - ); -} - -/// st.shared - 64b -template <> -CUTLASS_DEVICE -void shared_store<8>(uint32_t ptr, void const *src) { - uint2 const *dst_u64 = reinterpret_cast(src); - asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" - : : - "r"(ptr), - "r"(dst_u64->x), - "r"(dst_u64->y) - ); -} - -/// st.shared - 128b -template <> -CUTLASS_DEVICE -void shared_store<16>(uint32_t ptr, void const *src) { - uint4 const *dst_u128 = reinterpret_cast(src); - asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" - : : - "r"(ptr), - "r"(dst_u128->x), - "r"(dst_u128->y), - "r"(dst_u128->z), - "r"(dst_u128->w) - ); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/memory_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/memory_sm75.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/memory_sm75.h deleted file mode 100644 index 040f70743610da68344342a4e427f9182bc14c68..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/memory_sm75.h +++ /dev/null @@ -1,270 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Architecture-specific operators on memory added for SM75 -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/detail/helper_macros.hpp" -#include "cutlass/layout/matrix.h" -#include "cute/arch/copy_sm75.hpp" -#include "cute/arch/util.hpp" - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Layout of destination matrix (column-major implies transpose) - typename Layout, - /// .x1, .x2, or .x4 - int MatrixCount -> -CUTLASS_DEVICE void ldsm(Array & D, void const* ptr); - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Determine the appropriate way to target PTX's "ldmatrix" instruction. -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// CUTLASS helper to get SMEM pointer -CUTLASS_HOST_DEVICE unsigned cutlass_get_smem_pointer(void *ptr) { - return cute::cast_smem_ptr_to_uint(ptr); -} - -/// CUTLASS helper to get SMEM pointer -CUTLASS_DEVICE unsigned cutlass_get_smem_pointer(void const *ptr) { - return cutlass_get_smem_pointer(const_cast(ptr)); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -CUTLASS_DEVICE void ldsm( - Array & D, - void const* ptr) { - - #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) - - unsigned addr = cutlass_get_smem_pointer(ptr); - - int x; - asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); - reinterpret_cast(D) = x; - - #else - - CUTLASS_UNUSED(D); - CUTLASS_UNUSED(ptr); - CUTLASS_NOT_IMPLEMENTED(); - - #endif -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -CUTLASS_DEVICE void ldsm( - Array & D, - void const* ptr) { - - #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) - - unsigned addr = cutlass_get_smem_pointer(ptr); - - int x, y; - asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); - reinterpret_cast(D) = make_int2(x, y); - - #else - - CUTLASS_UNUSED(D); - CUTLASS_UNUSED(ptr); - CUTLASS_NOT_IMPLEMENTED(); - - #endif -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -CUTLASS_DEVICE void ldsm( - Array & D, - void const* ptr) { - - #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) - - unsigned addr = cutlass_get_smem_pointer(ptr); - - int x, y, z, w; - asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); - reinterpret_cast(D) = make_int4(x, y, z, w); - - #else - - CUTLASS_UNUSED(D); - CUTLASS_UNUSED(ptr); - CUTLASS_NOT_IMPLEMENTED(); - - #endif -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Transpose on 16b granularity -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -CUTLASS_DEVICE void ldsm( - Array & D, - void const* ptr) { - - #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) - - unsigned addr = cutlass_get_smem_pointer(ptr); - - int x; - asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); - reinterpret_cast(D) = x; - - #else - - CUTLASS_UNUSED(D); - CUTLASS_UNUSED(ptr); - CUTLASS_NOT_IMPLEMENTED(); - - #endif -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -CUTLASS_DEVICE void ldsm( - Array & D, - void const* ptr) { - - #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) - - unsigned addr = cutlass_get_smem_pointer(ptr); - - int x, y; - asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); - reinterpret_cast(D) = make_int2(x, y); - - #else - - CUTLASS_UNUSED(D); - CUTLASS_UNUSED(ptr); - CUTLASS_NOT_IMPLEMENTED(); - - #endif -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -CUTLASS_DEVICE void ldsm( - Array & D, - void const* ptr) { - - #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) - - unsigned addr = cutlass_get_smem_pointer(ptr); - - int x, y, z, w; - asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); - reinterpret_cast(D) = make_int4(x, y, z, w); - - #else - - CUTLASS_UNUSED(D); - CUTLASS_UNUSED(ptr); - CUTLASS_NOT_IMPLEMENTED(); - - #endif -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct shared_load_op { - CUTLASS_DEVICE - shared_load_op(AccessType &D, void const *ptr) { - D = *reinterpret_cast(ptr); - } -}; - -template -CUTLASS_DEVICE void shared_load(AccessType &D, void const *ptr) { - shared_load_op(D, ptr); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct shared_load_op { - CUTLASS_DEVICE - shared_load_op(AccessType &D, void const *ptr) { - unsigned addr = cutlass_get_smem_pointer(ptr); - - uint4 v; - asm volatile ("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];" : - "=r"(v.x), "=r"(v.y), "=r"(v.z), "=r"(v.w) : "r"(addr)); - - D = reinterpret_cast(v); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct shared_load_op { - CUTLASS_DEVICE - shared_load_op(AccessType &D, void const *ptr) { - unsigned addr = cutlass_get_smem_pointer(ptr); - - uint2 v; - asm volatile ("ld.shared.v2.b32 {%0, %1}, [%2];" : - "=r"(v.x), "=r"(v.y) : "r"(addr)); - - D = reinterpret_cast(v); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/memory_sm80.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/memory_sm80.h deleted file mode 100644 index b91a198b820dd374ba37a30c4f90d52cffbd95a7..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/memory_sm80.h +++ /dev/null @@ -1,473 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Architecture-specific operators on memory added for SM80 -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/complex.h" -#include "cutlass/arch/memory.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/cache_operation.h" -#include "cutlass/arch/synclog.hpp" - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - #define CUDA_CP_ASYNC_ACTIVATED 1 -#else - #define CUDA_CP_ASYNC_ACTIVATED 0 -#endif - -namespace cutlass { -namespace arch { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Initiates an asynchronous copy from global memory to shared memory. -/// -/// cp.async -/// -template < - /// Size of the access in bytes - int SizeInBytes, - /// Cache operation - CacheOperation::Kind cache_op = CacheOperation::Always> -struct cp_async; - -/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate -/// the entire transfer, zeros are written to SMEM if the guard predicate is false. -/// -/// cp.async -/// -template < - /// Size of the access in bytes - int SizeInBytes, - /// Cache operation - CacheOperation::Kind cache_op = CacheOperation::Always> -struct cp_async_zfill; - -/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate -/// the entire transfer, nans (0x7eff) are written to SMEM if the guard predicate is false. -/// -/// cp.async -/// -template < - /// Size of the access in bytes - int SizeInBytes, - /// Cache operation - CacheOperation::Kind cache_op = CacheOperation::Always> -struct cp_async_nan; - -/// Either 0 or 1 are written to SMEM based on input element type -/// Used for diagonal elements of triangular matrix of BLAS3 functions -/// -/// st.shared -/// -template < - /// Type of Element - typename Element, - /// If the data is for a Hermitian matrix diagonal - bool IsHermitianData = false> -struct cp_async_diag; - -static const uint32_t OOB_NAN_F16 = 0x7eff; -static const uint32_t OOB_NAN_F16x2 = ((OOB_NAN_F16 << 16) | OOB_NAN_F16); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization -template < - /// Size of the access in bytes - int SizeInBytes> -struct cp_async { - - /// Copy - CUTLASS_DEVICE - cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - #if CUDA_CP_ASYNC_ACTIVATED - - // Make sure the size is supported. - static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16), - "Size is not supported"); - - unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" -#if CUTLASS_ENABLE_L2_PREFETCH - " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n" -#else - " @p cp.async.ca.shared.global [%1], [%2], %3;\n" -#endif - "}\n" ::"r"((int)pred_guard), - "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); - - #else - using AccessType = Array; - - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } - #endif - } -}; - -/// Partial specialization -template < - /// Size of the access in bytes - int SizeInBytes> -struct cp_async_zfill { - - /// Copy with zero fill - CUTLASS_DEVICE - cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { - #if CUDA_CP_ASYNC_ACTIVATED - - // Make sure the size is supported. - static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16), - "Size is not supported"); - - unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); - int src_in_bytes = (pred_guard ? SizeInBytes : 0); - - asm volatile( -#if CUTLASS_ENABLE_L2_PREFETCH - "cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -#else - "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -#endif - "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); - - #else - using AccessType = Array; - - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } - else { - AccessType zeros; - zeros.clear(); - *static_cast(smem_ptr) = zeros; - } - #endif - } -}; - -/// Partial specialization -template <> -struct cp_async_nan<16, CacheOperation::Always> { - static int const kSizeInBytes = 16; - - /// Copy with nan fill - CUTLASS_DEVICE - cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { - #if CUDA_CP_ASYNC_ACTIVATED - - static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, - OOB_NAN_F16x2, OOB_NAN_F16x2}; - - unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" -#if CUTLASS_ENABLE_L2_PREFETCH - " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n" -#else - " @p cp.async.ca.shared.global [%1], [%2], %3;\n" -#endif - " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n" - "}\n" - : - : "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), - "n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z), - "r"(OOB_NAN_F16x8.w)); - - #else - - CUTLASS_UNUSED(smem_ptr); - CUTLASS_UNUSED(global_ptr); - CUTLASS_UNUSED(pred_guard); - CUTLASS_NOT_IMPLEMENTED(); - - #endif - } -}; - -/// Partial specialization to write one (1) -template -struct cp_async_diag { - using Element = Element_; - - CUTLASS_DEVICE - cp_async_diag(void *smem_ptr) { - #if CUDA_CP_ASYNC_ACTIVATED - - /// Values for the diagonal elements of the triangular input matrix - static __constant__ uint2 DIAG_DATA_DOUBLE_ONE = {0x3ff00000, 0x00000000}; - static __constant__ uint1 DIAG_DATA_FLOAT_ONE = {0x3f800000}; - static __constant__ uint1 DIAG_DATA_ZERO = {0x00000000}; - - unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); - - if (platform::is_same>::value) { - asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" - : : - "r"(smem_int_ptr), "r"(DIAG_DATA_DOUBLE_ONE.y), "r"(DIAG_DATA_DOUBLE_ONE.x), - "r"(DIAG_DATA_ZERO.x), "r"(DIAG_DATA_ZERO.x)); - } else if (platform::is_same>::value) { - asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" - : : - "r"(smem_int_ptr), "r"(DIAG_DATA_FLOAT_ONE.x), "r"(DIAG_DATA_ZERO.x)); - } else if (platform::is_same::value) { - asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" - : : - "r"(smem_int_ptr), "r"(DIAG_DATA_DOUBLE_ONE.y),"r"(DIAG_DATA_DOUBLE_ONE.x)); - } else if (platform::is_same::value) { - asm volatile("st.shared.u32 [%0], %1;\n" - : : - "r"(smem_int_ptr), "r"(DIAG_DATA_FLOAT_ONE.x)); - } else { - CUTLASS_UNUSED(smem_int_ptr); - CUTLASS_NOT_IMPLEMENTED(); - } - - #else - - CUTLASS_UNUSED(smem_ptr); - CUTLASS_NOT_IMPLEMENTED(); - - #endif - } -}; - -/// Partial specialization to write zero for the imaginary part of Hermitian data -template -struct cp_async_diag { - using Element = Element_; - - CUTLASS_DEVICE - cp_async_diag(void *smem_ptr) { - #if CUDA_CP_ASYNC_ACTIVATED - - /// Values for the diagonal elements of the triangular input matrix - static __constant__ uint1 DIAG_DATA_ZERO = {0x00000000}; - - unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); - - if (platform::is_same>::value) { - asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" - : : - "r"(smem_int_ptr), "r"(DIAG_DATA_ZERO.x), "r"(DIAG_DATA_ZERO.x)); - } else if (platform::is_same>::value) { - asm volatile("st.shared.u32 [%0], %1;\n" - : : - "r"(smem_int_ptr), "r"(DIAG_DATA_ZERO.x)); - } else { - CUTLASS_UNUSED(smem_int_ptr); - CUTLASS_NOT_IMPLEMENTED(); - } - - #else - - CUTLASS_UNUSED(smem_ptr); - CUTLASS_NOT_IMPLEMENTED(); - - #endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization -template < - /// Size of the access in bytes - int SizeInBytes> -struct cp_async { - - /// Copy - CUTLASS_DEVICE - cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - #if CUDA_CP_ASYNC_ACTIVATED - - static_assert(SizeInBytes == 16, - "cp.async only supports CacheOperation::Global when access size is 16B."); - - unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); - cutlass::arch::synclog_emit_cp_async(__LINE__, smem_int_ptr, global_ptr, pred_guard, SizeInBytes); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" -#if CUTLASS_ENABLE_L2_PREFETCH - " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" -#else - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" -#endif - "}\n" ::"r"((int)pred_guard), - "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); - - #else - using AccessType = Array; - - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } - #endif - } -}; - -/// Partial specialization -template < - /// Size of the access in bytes - int SizeInBytes> -struct cp_async_zfill { - - /// Copy with zero fill - CUTLASS_DEVICE - cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - #if CUDA_CP_ASYNC_ACTIVATED - - static_assert(SizeInBytes == 16, - "cp.async only supports CacheOperation::Global when access size is 16B."); - - unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); - int src_in_bytes = (pred_guard ? SizeInBytes : 0); - cutlass::arch::synclog_emit_cp_async_zfill(__LINE__, smem_int_ptr, global_ptr, pred_guard, SizeInBytes); - - asm volatile( -#if CUTLASS_ENABLE_L2_PREFETCH - "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -#else - "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -#endif - "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); - - #else - using AccessType = Array; - - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } - else { - AccessType zeros; - zeros.clear(); - *static_cast(smem_ptr) = zeros; - } - #endif - } -}; - -/// Partial specialization -template <> -struct cp_async_nan<16, CacheOperation::Global> { - static int const kSizeInBytes = 16; - - /// Copy with nan fill - CUTLASS_DEVICE - cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { - #if CUDA_CP_ASYNC_ACTIVATED - - static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, - OOB_NAN_F16x2, OOB_NAN_F16x2}; - - unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); - cutlass::arch::synclog_emit_cp_async_nan(__LINE__, smem_int_ptr, global_ptr, pred_guard); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" -#if CUTLASS_ENABLE_L2_PREFETCH - " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" -#else - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" -#endif - " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n" - "}\n" - : - : "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), - "n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z), - "r"(OOB_NAN_F16x8.w)); - - #else - - CUTLASS_UNUSED(smem_ptr); - CUTLASS_UNUSED(global_ptr); - CUTLASS_UNUSED(pred_guard); - CUTLASS_NOT_IMPLEMENTED(); - - #endif - } -}; -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. -CUTLASS_DEVICE -void cp_async_fence() { - #if CUDA_CP_ASYNC_ACTIVATED - asm volatile("cp.async.commit_group;\n" ::); - cutlass::arch::synclog_emit_cp_async_fence(__LINE__); - #endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Blocks until all but previous cp.async.commit_group operations have committed. -template -CUTLASS_DEVICE void cp_async_wait() { - #if CUDA_CP_ASYNC_ACTIVATED - asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); - cutlass::arch::synclog_emit_cp_async_wait(__LINE__, N); - #endif -} - -/// Blocks until all previous cp.async.commit_group operations have committed. -template <> -CUTLASS_DEVICE void cp_async_wait<0>() { - #if CUDA_CP_ASYNC_ACTIVATED - asm volatile("cp.async.wait_all;\n" ::); - cutlass::arch::synclog_emit_cp_async_wait_all(__LINE__); - #endif -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma.h deleted file mode 100644 index 40c8200f2805c1d7a94774377ba90425d0d988cd..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma.h +++ /dev/null @@ -1,276 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing architecture support for multiply-add operations -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/functional.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/arch/arch.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag indicating the operation implied by MMA. -struct OpMultiplyAdd {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag indicating the result is saturated to MAX_FLOAT|MIN_FLOAT or MAX_INT|MIN_INT -struct OpMultiplyAddSaturate {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag indicating the input is converted to a narrower type (BF16) -struct OpMultiplyAddFastBF16 {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag indicating the input is converted to a narrower type (F16) -struct OpMultiplyAddFastF16 {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag indicating the input data types are mixed and the narrower type is -/// upcasted to the wider type -struct OpMultiplyAddMixedInputUpcast {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Tag indicating the input is converted to 2 (big and small) TF32 or FP16 components -// Perform 3xTF32 or 4xTF32 for every F32 output element on Ampere -// Perform 3xFP16 or 4xFP16 for every F32 output element on Hopper with axiswise quantization factor support -struct OpMultiplyAddFastF32 {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Tag indicating the input is converted to 2 (big and small) TF32 or FP16 components -// Perform 3xTF32 or 4xTF32 for every complex output element on Ampere -// Perform 3xFP16 or 4xFP16 for every complex output element on Hopper with axiswise quantization factor support -struct OpMultiplyAddComplexFastF32 {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag indicating that staged accumulation is not to be used. This is valid only for SM89 -/// FP8 kernels. -struct OpMultiplyAddFastAccum; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag indicating the complex multiply-add operation -struct OpMultiplyAddComplex {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag indicating the gaussian complex multiply-add operation -struct OpMultiplyAddGaussianComplex {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag indicating the inner product is defined by (XOR, POPC) -struct OpXorPopc {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag indicating the inner product is defined by (AND, POPC) -struct OpAndPopc {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag classifying math operators as thread-level operations. -struct OpClassSimt {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag classifying operators as Tensor Core operations. -struct OpClassTensorOp {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Tag classifying operators as WMMA Tensor Core operations -struct OpClassWmmaTensorOp {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag classifying operators as Tensor Core with structure sparse operations. -struct OpClassSparseTensorOp {}; - - -/// Tag classifying operators as Tensor Core with blockScaled -struct OpClassBlockScaledTensorOp {}; - -/// Tag classifying operators as Tensor Core with blockScaled structured sparse operations. -struct OpClassBlockScaledSparseTensorOp {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template < - /// Size of the matrix product (concept: GemmShape) - typename Shape_, - /// Number of threads participating - int kThreads_, - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Inner product operator - typename Operator -> -struct Mma; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation - specialized for 1x1x1x1 matrix multiply operation -template < - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Inner product operator - typename Operator_ -> -struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = Operator_; - using ElementC = ElementC_; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - - multiply_add op; - - d[0] = op(a[0], b[0], c[0]); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specifies internal data type for computation -struct SPFormatType { - enum Kind { - Thread - }; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template < - /// Size of the matrix product (concept: GemmShape) - typename Shape_, - /// Number of threads participating - int kThreads_, - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Inner product operator - typename Operator, - /// Specifies meta data format - SPFormatType::Kind SPFormat = SPFormatType::Thread -> -struct SparseMma; - -} // namespace arch -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Specializations for each compute capability -// - -#include "cutlass/arch/mma_sm50.h" -#include "cutlass/arch/mma_sm60.h" -#include "cutlass/arch/mma_sm61.h" -#include "cutlass/arch/mma_sm70.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" -#include "cutlass/arch/mma_sparse_sm80.h" -#include "cutlass/arch/mma_sm89.h" -#include "cutlass/arch/mma_sparse_sm89.h" -#include "cutlass/arch/mma_sm90.h" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { -namespace detail { -/// Helper for determining whether staged accumulation should be used for a given operator -template -struct UseStagedAccumulation { - static bool const value = platform::is_same::value || - platform::is_same::value || - is_sm89_staged_policy_v; -}; -} // namespace detail -} // namespace arch -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm100.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm100.h deleted file mode 100644 index 2863f2d2c4dfdad667db8a99b757c6bc94588b53..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm100.h +++ /dev/null @@ -1,118 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Matrix multiply -*/ - -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) - -#include "cutlass/arch/mma.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/config.h" -#include "cute/arch/simd_sm100.hpp" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass{ -namespace arch { - - -/// Matrix multiply-add operation -template < - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC -> -struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, OpMultiplyAdd> { - - using Shape = gemm::GemmShape<2, 1, 1>; - using Operator = OpMultiplyAdd; - using ElementC = ElementC_; - - CUTLASS_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - d[i] = a[i] * b[0] + c[i]; - } - } -}; - -/// Matrix multiply-add operation -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> { - - using Shape = gemm::GemmShape<2, 1, 1>; - using Operator = OpMultiplyAdd; - using ElementC = float; - - CUTLASS_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - float2 result; - cute::fma(result, make_float2(a[0], a[1]), make_float2(b[0], b[0]), make_float2(c[0], c[1])); - d[0] = result.x; - d[1] = result.y; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm50.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm50.h deleted file mode 100644 index 1701158b0bdd479cb179e4d0162c78ab335aba8a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm50.h +++ /dev/null @@ -1,432 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Matrix multiply -*/ - -#pragma once - -#include "cutlass/arch/mma.h" -#include "cutlass/complex.h" -#include "cutlass/quaternion.h" -#include "cutlass/functional.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/gemm/gemm.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = OpMultiplyAdd; - using ElementC = float; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - d[0] = a[0] * b[0] + c[0]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = OpMultiplyAdd; - using ElementC = double; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - - d[0] = a[0] * b[0] + c[0]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = OpMultiplyAdd; - using ElementC = int; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - - d[0] = a[0] * b[0] + c[0]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma< - gemm::GemmShape<1, 1, 1>, - 1, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = OpMultiplyAddComplex; - using ElementC = complex; - - CUTLASS_HOST_DEVICE - void operator()( - Array, 1> &d, - Array, 1> const &a, - Array, 1> const &b, - Array, 1> const &c - ) { - - d[0].real() = a[0].real() * b[0].real() + c[0].real(); - d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); - d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); - d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma< - gemm::GemmShape<1, 1, 1>, - 1, - complex, - LayoutA, - float, - LayoutB, - complex, - LayoutC, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = OpMultiplyAddComplex; - using ElementC = complex; - - CUTLASS_HOST_DEVICE - void operator()( - Array, 1> &d, - Array, 1> const &a, - Array const &b, - Array, 1> const &c - ) { - - d[0].real() = a[0].real() * b[0] + c[0].real(); - d[0].imag() = a[0].imag() * b[0] + c[0].imag(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma< - gemm::GemmShape<1, 1, 1>, - 1, - float, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = OpMultiplyAddComplex; - using ElementC = complex; - - CUTLASS_HOST_DEVICE - void operator()( - Array, 1> &d, - Array const &a, - Array, 1> const &b, - Array, 1> const &c - ) { - - d[0].real() = a[0] * b[0].real() + c[0].real(); - d[0].imag() = a[0] * b[0].imag() + d[0].imag(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma< - gemm::GemmShape<1, 1, 1>, - 1, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = OpMultiplyAddComplex; - using ElementC = complex; - - CUTLASS_HOST_DEVICE - void operator()( - Array, 1> &d, - Array, 1> const &a, - Array, 1> const &b, - Array, 1> const &c - ) { - - d[0].real() = a[0].real() * b[0].real() + c[0].real(); - d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); - d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); - d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); - } -}; - -/// Matrix multiply-add operation -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma< - gemm::GemmShape<1, 1, 1>, - 1, - complex, - LayoutA, - double, - LayoutB, - complex, - LayoutC, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = OpMultiplyAddComplex; - using ElementC = complex; - - CUTLASS_HOST_DEVICE - void operator()( - Array, 1> &d, - Array, 1> const &a, - Array const &b, - Array, 1> const &c - ) { - - d[0].real() = a[0].real() * b[0] + c[0].real(); - d[0].imag() = a[0].imag() * b[0] + c[0].imag(); - } -}; - -/// Matrix multiply-add operation -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma< - gemm::GemmShape<1, 1, 1>, - 1, - double, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = OpMultiplyAddComplex; - using ElementC = complex; - - CUTLASS_HOST_DEVICE - void operator()( - Array, 1> &d, - Array const &a, - Array, 1> const &b, - Array, 1> const &c - ) { - - d[0].real() = a[0] * b[0].real() + c[0].real(); - d[0].imag() = a[0] * b[0].imag() + d[0].imag(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = OpMultiplyAdd; - using ElementC = float; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - d[0] = float(a[0]) * float(b[0]) + c[0]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation for Quaternions -template < - /// Layout of A matrix - typename LayoutA, - /// Layout of B matrix - typename LayoutB, - /// Layout of C matrix - typename LayoutC -> -struct Mma, 1, Quaternion, LayoutA, Quaternion, LayoutB, Quaternion, LayoutC, OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 1>; - using Operator = OpMultiplyAdd; - using Element = Quaternion; - using ElementC = Element; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - multiply_add op; - d[0] = op(a[0], b[0], c[0]); - } - -}; - -} -} - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm60.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm60.h deleted file mode 100644 index 31ef2b653076863cfb9387ba078d31ee8b52d607..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm60.h +++ /dev/null @@ -1,252 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Matrix multiply -*/ - -#pragma once - -#include - -#include "cutlass/arch/mma.h" - -#include "cutlass/layout/matrix.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template -struct Mma< - gemm::GemmShape<2,1,1>, - 1, - half_t, - LayoutA, - half_t, - LayoutB, - half_t, - LayoutC, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<2, 1, 1>; - using Operator = OpMultiplyAdd; - using ElementC = half_t; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) - - __half2 const & A = reinterpret_cast<__half2 const &>(a); - __half2 B = __half2half2(reinterpret_cast<__half const &>(b)); - __half2 const & C = reinterpret_cast<__half2 const &>(c); - - __half2 D = __hfma2(A, B, C); - - d = reinterpret_cast &>(D); - -#else - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - d[i] = a[i] * b[0] + c[i]; - } -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template -struct Mma< - gemm::GemmShape<1,2,1>, - 1, - half_t, - LayoutA, - half_t, - LayoutB, - half_t, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 2, 1>; - using Operator = OpMultiplyAdd; - using ElementC = half_t; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) - - __half2 const & A = __half2half2(reinterpret_cast<__half const &>(a)); - __half2 B = reinterpret_cast<__half2 const &>(b); - __half2 const & C = reinterpret_cast<__half2 const &>(c); - - __half2 D = __hfma2(A, B, C); - - d = reinterpret_cast &>(D); - -#else - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - d[i] = a[0] * b[i] + c[i]; - } -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template <> -struct Mma < - gemm::GemmShape<2, 2, 1>, - 1, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<2, 2, 1>; - using Operator = OpMultiplyAdd; - using ElementC = half_t; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) - - __half2 const & A = reinterpret_cast<__half2 const &>(a); - __half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b)); - __half2 Bhi = __high2half2(reinterpret_cast<__half2 const &>(b)); - - __half2 const *C = reinterpret_cast<__half2 const *>(&c); - - __half2 Dlo = __hfma2(A, Blo, C[0]); - __half2 Dhi = __hfma2(A, Bhi, C[1]); - - Array * D = reinterpret_cast *>(&d); - - D[0] = reinterpret_cast const &>(Dlo); - D[1] = reinterpret_cast const &>(Dhi); - -#else - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < 2; ++j) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j]; - } - } -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template <> -struct Mma< - gemm::GemmShape<2, 2, 1>, - 1, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - half_t, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<2, 2, 1>; - using Operator = OpMultiplyAdd; - using ElementC = half_t; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) - - __half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a)); - __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a)); - __half2 const & B = reinterpret_cast<__half2 const &>(b); - - __half2 const *C = reinterpret_cast<__half2 const *>(&c); - - __half2 Dlo = __hfma2(Alo, B, C[0]); - __half2 Dhi = __hfma2(Ahi, B, C[1]); - - Array * D = reinterpret_cast *>(&d); - - D[0] = reinterpret_cast &>(Dlo); - D[1] = reinterpret_cast &>(Dhi); -#else - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < 2; ++j) { - d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j]; - } - } -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} -} diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm61.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm61.h deleted file mode 100644 index b780335efadeecee07f7c1c98422f18fec6f7ea3..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm61.h +++ /dev/null @@ -1,142 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Matrix multiply -*/ - -#pragma once - -#include "cutlass/layout/matrix.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template -struct Mma< - gemm::GemmShape<1,1,4>, - 1, - int8_t, - LayoutA, - int8_t, - LayoutB, - int, - LayoutC, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 4>; - using Operator = OpMultiplyAdd; - using ElementC = int; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) - - unsigned const &A = reinterpret_cast(a); - unsigned const &B = reinterpret_cast(b); - - asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" - : "=r"(d[0]) - : "r"(A), "r"(B), "r"(c[0])); - -#else - - d[0] = c[0]; - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < 4; ++k) { - d[0] += a[k] * b[k]; - } - -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template -struct Mma< - gemm::GemmShape<1, 1, 2>, - 1, - int16_t, - layout::RowMajor, - int16_t, - layout::ColumnMajor, - int, - LayoutC, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<1, 1, 2>; - using Operator = OpMultiplyAdd; - using ElementC = int; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) - - unsigned const &A = reinterpret_cast(a); - unsigned const &B = reinterpret_cast(b); - - asm volatile("dp2a.s32.s32 %0, %1, %2, %3;" - : "=r"(d[0]) - : "r"(A), "r"(B), "r"(c[0])); -#else - d[0] = c[0]; - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < 2; ++k) { - d[0] += a[k] * b[k]; - } -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} -} diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm70.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm70.h deleted file mode 100644 index 6acdcfac3b9d3d10253d3a343a1d097b617ddb16..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm70.h +++ /dev/null @@ -1,661 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Matrix multiply -*/ -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) - -#include "mma.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) -#define CUTLASS_ARCH_MMA_SM70_SUPPORTED -#endif - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) - -#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 &&__CUDACC_VER_MINOR__ >= 1)) -#define CUTLASS_ARCH_MMA_SM70_ENABLED -#endif - -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Matrix multiply accumulate 884 - FP16 accumulation -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -template <> -struct Mma< - gemm::GemmShape<8,8,4>, - 8, - half_t, - layout::ColumnMajor, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 4>; - - using ElementA = half_t; - using LayoutA = layout::ColumnMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = half_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm70; - - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) { - -#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) - - unsigned const *A = reinterpret_cast(&a); - unsigned const *B = reinterpret_cast(&b); - unsigned const *C = reinterpret_cast(&c); - unsigned *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) - ); - -#else - assert(0); - #if defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); - #endif -#endif - } -}; - -/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -template <> -struct Mma< - gemm::GemmShape<8, 8, 4>, - 8, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - half_t, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 4>; - - using ElementA = half_t; - using LayoutA = layout::ColumnMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::RowMajor; - using FragmentB = Array; - - using ElementC = half_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm70; - - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) { - -#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) - - unsigned const *A = reinterpret_cast(&a); - unsigned const *B = reinterpret_cast(&b); - unsigned const *C = reinterpret_cast(&c); - unsigned *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) - ); - -#else - assert(0); - #if defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); - #endif -#endif - } -}; - -/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -template <> -struct Mma< - gemm::GemmShape<8, 8, 4>, - 8, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 4>; - - using ElementA = half_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = half_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm70; - - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) { - -#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) - - unsigned const *A = reinterpret_cast(&a); - unsigned const *B = reinterpret_cast(&b); - unsigned const *C = reinterpret_cast(&c); - unsigned *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) - ); - -#else - assert(0); - #if defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); - #endif -#endif - } -}; - -/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -template <> -struct Mma< - gemm::GemmShape<8, 8, 4>, - 8, - half_t, - layout::RowMajor, - half_t, - layout::RowMajor, - half_t, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 4>; - - using ElementA = half_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::RowMajor; - using FragmentB = Array; - - using ElementC = half_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm70; - - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) { - -#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) - - unsigned const *A = reinterpret_cast(&a); - unsigned const *B = reinterpret_cast(&b); - unsigned const *C = reinterpret_cast(&c); - unsigned *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) - ); - -#else - assert(0); - #if defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); - #endif -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Matrix multiply accumulate 884 - FP32 accumulation -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 4>, - 8, - half_t, - layout::ColumnMajor, - half_t, - layout::ColumnMajor, - float, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 4>; - - using ElementA = half_t; - using LayoutA = layout::ColumnMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm70; - - /// Multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) { - -#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) - - unsigned const *A = reinterpret_cast(&a); - unsigned const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " - "{%12,%13,%14,%15,%16,%17,%18,%19};\n" - : "=f"(D[0]), - "=f"(D[1]), - "=f"(D[2]), - "=f"(D[3]), - "=f"(D[4]), - "=f"(D[5]), - "=f"(D[6]), - "=f"(D[7]) - : "r"(A[0]), - "r"(A[1]), - "r"(B[0]), - "r"(B[1]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3]), - "f"(C[4]), - "f"(C[5]), - "f"(C[6]), - "f"(C[7]) - ); - -#else - assert(0); - #if defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); - #endif -#endif - } -}; - -/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 4>, - 8, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - float, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 4>; - - using ElementA = half_t; - using LayoutA = layout::ColumnMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::RowMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm70; - - /// Multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) { - -#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) - - unsigned const *A = reinterpret_cast(&a); - unsigned const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " - "{%12,%13,%14,%15,%16,%17,%18,%19};\n" - : "=f"(D[0]), - "=f"(D[1]), - "=f"(D[2]), - "=f"(D[3]), - "=f"(D[4]), - "=f"(D[5]), - "=f"(D[6]), - "=f"(D[7]) - : "r"(A[0]), - "r"(A[1]), - "r"(B[0]), - "r"(B[1]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3]), - "f"(C[4]), - "f"(C[5]), - "f"(C[6]), - "f"(C[7]) - ); - -#else - assert(0); - #if defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); - #endif -#endif - } -}; - -/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 4>, - 8, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - float, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 4>; - - using ElementA = half_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm70; - - /// Multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) { - -#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) - - unsigned const *A = reinterpret_cast(&a); - unsigned const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " - "{%12,%13,%14,%15,%16,%17,%18,%19};\n" - : "=f"(D[0]), - "=f"(D[1]), - "=f"(D[2]), - "=f"(D[3]), - "=f"(D[4]), - "=f"(D[5]), - "=f"(D[6]), - "=f"(D[7]) - : "r"(A[0]), - "r"(A[1]), - "r"(B[0]), - "r"(B[1]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3]), - "f"(C[4]), - "f"(C[5]), - "f"(C[6]), - "f"(C[7]) - ); - -#else - assert(0); - #if defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); - #endif -#endif - } -}; - -/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 4>, - 8, - half_t, - layout::RowMajor, - half_t, - layout::RowMajor, - float, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 4>; - - using ElementA = half_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::RowMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm70; - - /// Multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) { - -#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) - - unsigned const *A = reinterpret_cast(&a); - unsigned const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " - "{%12,%13,%14,%15,%16,%17,%18,%19};\n" - : "=f"(D[0]), - "=f"(D[1]), - "=f"(D[2]), - "=f"(D[3]), - "=f"(D[4]), - "=f"(D[5]), - "=f"(D[6]), - "=f"(D[7]) - : "r"(A[0]), - "r"(A[1]), - "r"(B[0]), - "r"(B[1]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3]), - "f"(C[4]), - "f"(C[5]), - "f"(C[6]), - "f"(C[7]) - ); - -#else - assert(0); - #if defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); - #endif -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation specialized for the entire warp -template < - typename LayoutA, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename Operator -> -struct Mma< - gemm::GemmShape<16, 16, 4>, - 32, - half_t, - LayoutA, - half_t, - LayoutB, - ElementC, - LayoutC, - Operator -> : - public Mma< - gemm::GemmShape<8, 8, 4>, - 8, - half_t, - LayoutA, - half_t, - LayoutB, - ElementC, - LayoutC, - Operator> { - - using Shape = gemm::GemmShape<16, 16, 4>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm75.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm75.h deleted file mode 100644 index c71ea076b5c2390cea8b0ba17ae1b642c5d49b48..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm75.h +++ /dev/null @@ -1,789 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Matrix multiply for SM75 -*/ - -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) - -#include "cutlass/arch/wmma.h" - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) -// CUDA Toolkit includes for nvcuda::wmma needed for binarized matrix multiply. -#include -#include "cutlass/wmma_array.h" -#endif - -// CUTLASS includes -#include "cutlass/arch/mma.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////// - -#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) - -#define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1 - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) -#define CUTLASS_ARCH_MMA_SM75_ENABLED -#endif -#endif - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 1688 - FP16 accumulation -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation - F16 = F16 * F16 + F16 -template <> -struct Mma< - gemm::GemmShape<16, 8, 8>, - 32, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16, 8, 8>; - - using ElementA = half_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = half_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const *A = reinterpret_cast(&a); - unsigned const *B = reinterpret_cast(&b); - unsigned const *C = reinterpret_cast(&c); - unsigned *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1])); - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 1688 - FP32 accumulation -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 8>, - 32, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - float, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16, 8, 8>; - - using ElementA = half_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const *A = reinterpret_cast(&a); - unsigned const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : - "r"(A[0]), "r"(A[1]), - "r"(B[0]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) - ); - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Integer matrix multiply (8b) with SATURATE -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - uint8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - int8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - uint8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Integer matrix multiply (4b) - SATURATE -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 32>, - 32, - int4b_t, - layout::RowMajor, - int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<8, 8, 32>; - - using ElementA = int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 32>, - 32, - uint4b_t, - layout::RowMajor, - int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<8, 8, 32>; - - using ElementA = uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 32>, - 32, - int4b_t, - layout::RowMajor, - uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<8, 8, 32>; - - using ElementA = int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 32>, - 32, - uint4b_t, - layout::RowMajor, - uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<8, 8, 32>; - - using ElementA = uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// b1 ^ b1 + s32 => s32 -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation -template <> -struct Mma< - gemm::GemmShape<8,8,128>, - 32, - uint1b_t, - layout::RowMajor, - uint1b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpXorPopc> { - - using Shape = gemm::GemmShape<8,8,128>; - - using ElementA = uint1b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint1b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpXorPopc; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) - using WmmaFragmentA = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_a, - Shape::kM, - Shape::kN, - Shape::kK, - nvcuda::wmma::experimental::precision::b1, - nvcuda::wmma::row_major>; - - using WmmaFragmentB = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_b, - Shape::kM, - Shape::kN, - Shape::kK, - nvcuda::wmma::experimental::precision::b1, - nvcuda::wmma::col_major>; - - using WmmaFragmentC = nvcuda::wmma::fragment< - nvcuda::wmma::accumulator, - Shape::kM, - Shape::kN, - Shape::kK, - int>; - - WmmaFragmentA const & A = reinterpret_cast(a); - WmmaFragmentB const & B = reinterpret_cast(b); - - WmmaFragmentC const & C = reinterpret_cast(c); - WmmaFragmentC & D = reinterpret_cast(d); - - nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, - nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); // WMMA must be supported to issue binary matrix multiply-accumulate instructions. - -#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) - -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm80.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm80.h deleted file mode 100644 index 22cd87d65b0412e9ac9a4953feee022c5e5feb92..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm80.h +++ /dev/null @@ -1,1500 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Matrix multiply -*/ - -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) - -#include "mma.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////// - -#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) - -#define CUTLASS_ARCH_MMA_SM80_SUPPORTED 1 - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -#define CUTLASS_ARCH_MMA_SM80_ENABLED - -#if (__CUDA_ARCH__ <= 900) -#define CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED -#endif -#if (__CUDA_ARCH__ <= 890) -#define CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED -#endif - -#endif - -#endif - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 1688 - Float BF16, FP32 accumulation -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation - F32 = bf16 * bf16 + F32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 8>, - 32, - bfloat16_t, - layout::RowMajor, - bfloat16_t, - layout::ColumnMajor, - float, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16, 8, 8>; - - using ElementA = bfloat16_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = bfloat16_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm( - "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : - "r"(A[0]), "r"(A[1]), - "r"(B[0]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) - ); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 1684 - Float TF32 -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 4>, - 32, - tfloat32_t, - layout::RowMajor, - tfloat32_t, - layout::ColumnMajor, - float, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16, 8, 4>; - - using ElementA = tfloat32_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = tfloat32_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : - "r"(A[0]), "r"(A[1]), - "r"(B[0]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) - ); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 1688 - Float TF32 -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 -template <> -struct Mma, 32, tfloat32_t, layout::RowMajor, - tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, - OpMultiplyAdd> { - using Shape = gemm::GemmShape<16, 8, 8>; - - using ElementA = tfloat32_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = tfloat32_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 16816 -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -template <> -struct Mma< - gemm::GemmShape<16, 8, 16>, - 32, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16, 8, 16>; - - using ElementA = half_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = half_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - uint32_t const *C = reinterpret_cast(&c); - uint32_t *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]) - ); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 16>, - 32, - bfloat16_t, - layout::RowMajor, - bfloat16_t, - layout::ColumnMajor, - float, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16, 8, 16>; - - using ElementA = bfloat16_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = bfloat16_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 16>, - 32, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - float, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16, 8, 16>; - - using ElementA = half_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 884 - F64 -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F64 = F64 * F64 + F64 -template <> -struct Mma< - gemm::GemmShape<8,8,4>, - 32, - double, - layout::RowMajor, - double, - layout::ColumnMajor, - double, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8,8,4>; - - using ElementA = double; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = double; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = double; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - - using ArchTag = arch::Sm80; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - double const & A = reinterpret_cast(a); - double const & B = reinterpret_cast(b); - - double const *C = reinterpret_cast(&c); - double *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=d"(D[0]), "=d"(D[1]) - : "d"(A), "d"(B), "d"(C[0]), "d"(C[1])); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - uint8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - int8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - uint8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const * A = reinterpret_cast(&a); - uint32_t const * B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - uint8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - int8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - uint8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 64>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16, 8, 64>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const * A = reinterpret_cast(&a); - uint32_t const * B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 64>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16, 8, 64>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 64>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16, 8, 64>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 64>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16, 8, 64>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 168256 - B1 input, S32 accumulation - AND,POPC -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = B1 & B1 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,256>, - 32, - cutlass::uint1b_t, - layout::RowMajor, - cutlass::uint1b_t, - layout::ColumnMajor, - int32_t, - layout::RowMajor, - OpAndPopc> { - - using Shape = gemm::GemmShape<16,8,256>; - - using ElementA = cutlass::uint1b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint1b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int32_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpAndPopc; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = B1 & B1 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,256>, - 32, - cutlass::uint1b_t, - layout::RowMajor, - cutlass::uint1b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,256>; - - using ElementA = cutlass::uint1b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint1b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int32_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 168256 - B1 input, S32 accumulation - XOR,POPC -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = B1 & B1 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,256>, - 32, - cutlass::uint1b_t, - layout::RowMajor, - cutlass::uint1b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpXorPopc> { - - using Shape = gemm::GemmShape<16,8,256>; - - using ElementA = cutlass::uint1b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint1b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpXorPopc; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); - -#endif // defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm89.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm89.h deleted file mode 100644 index 4bcd9bc1de9b6e53629e08f478a50d791d198a1a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm89.h +++ /dev/null @@ -1,641 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Matrix multiply-accumulate specialzied for SM89 -*/ - -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) - -#include "mma.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////// - -#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) -# define CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED -#endif - -#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) -# define CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED -#endif - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) -# if defined(CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED) -# define CUTLASS_ARCH_MMA_F32_SM89_ENABLED -# endif - -# if defined(CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED) -# define CUTLASS_ARCH_MMA_F16_SM89_ENABLED -# endif -#endif - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -//////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -// Whether the Mma uses as SM89 staged accumulation policy -template -static constexpr bool is_sm89_staged_policy_v = - ( - // ElementA must be FP8 - platform::is_same::value || - platform::is_same::value - ) && - ( - // ElementB must be FP8 - platform::is_same::value || - platform::is_same::value - ) && - ( - // The instruction shape must be 16x8x32 - Operator::ArchMmaOperator::Shape::kM == 16 && - Operator::ArchMmaOperator::Shape::kN == 8 && - Operator::ArchMmaOperator::Shape::kK == 32 - ) && - ( - // The operator must be OpMultiplyAdd (default) - platform::is_same::value - ); -} // namespace detail - -//////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 16832 - Float {E4M3, E5M2}, FP32 accumulation -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation - F32 = fe4m3 * fe4m3 + F32 -template -struct Mma< - gemm::GemmShape<16, 8, 32>, - 32, - cutlass::float_e4m3_t, - layout::RowMajor, - cutlass::float_e4m3_t, - layout::ColumnMajor, - float, - layout::RowMajor, - Operator_> { - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16, 8, 32>; - - using ElementA = cutlass::float_e4m3_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e4m3_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm( - "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : - "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) - ); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -/// Matrix multiply-add operation - F32 = fe4m3 * fe5m2 + F32 -template -struct Mma< - gemm::GemmShape<16, 8, 32>, - 32, - cutlass::float_e4m3_t, - layout::RowMajor, - cutlass::float_e5m2_t, - layout::ColumnMajor, - float, - layout::RowMajor, - Operator_> { - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16, 8, 32>; - - using ElementA = cutlass::float_e4m3_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e5m2_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm( - "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : - "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) - ); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -/// Matrix multiply-add operation - F32 = fe5m2 * fe4m3 + F32 -template -struct Mma< - gemm::GemmShape<16, 8, 32>, - 32, - cutlass::float_e5m2_t, - layout::RowMajor, - cutlass::float_e4m3_t, - layout::ColumnMajor, - float, - layout::RowMajor, - Operator_> { - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16, 8, 32>; - - using ElementA = cutlass::float_e5m2_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e4m3_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm( - "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : - "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) - ); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -/// Matrix multiply-add operation - F32 = fe5m2 * fe5m2 + F32 -template -struct Mma< - gemm::GemmShape<16, 8, 32>, - 32, - cutlass::float_e5m2_t, - layout::RowMajor, - cutlass::float_e5m2_t, - layout::ColumnMajor, - float, - layout::RowMajor, - Operator_> { - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16, 8, 32>; - - using ElementA = cutlass::float_e5m2_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e5m2_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - asm( - "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : - "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) - ); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 16832 - Float {E4M3, E5M2}, FP16 accumulation -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation - F16 = fe4m3 * fe4m3 + F16 -template -struct Mma< - gemm::GemmShape<16, 8, 32>, - 32, - cutlass::float_e4m3_t, - layout::RowMajor, - cutlass::float_e4m3_t, - layout::ColumnMajor, - cutlass::half_t, - layout::RowMajor, - Operator_> { - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16, 8, 32>; - - using ElementA = cutlass::float_e4m3_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e4m3_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = cutlass::half_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - uint32_t const *C = reinterpret_cast(&c); - uint32_t *D = reinterpret_cast(&d); - - asm( - "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " - "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" - : "=r"(D[0]), "=r"(D[1]) - : - "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]) - ); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -/// Matrix multiply-add operation - F16 = fe4m3 * fe5m2 + F16 -template -struct Mma< - gemm::GemmShape<16, 8, 32>, - 32, - cutlass::float_e4m3_t, - layout::RowMajor, - cutlass::float_e5m2_t, - layout::ColumnMajor, - cutlass::half_t, - layout::RowMajor, - Operator_> { - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16, 8, 32>; - - using ElementA = cutlass::float_e4m3_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e5m2_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = cutlass::half_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - uint32_t const *C = reinterpret_cast(&c); - uint32_t *D = reinterpret_cast(&d); - - asm( - "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16 " - "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" - : "=r"(D[0]), "=r"(D[1]) - : - "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]) - ); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -/// Matrix multiply-add operation - F16 = fe5m2 * fe4m3 + F16 -template -struct Mma< - gemm::GemmShape<16, 8, 32>, - 32, - cutlass::float_e5m2_t, - layout::RowMajor, - cutlass::float_e4m3_t, - layout::ColumnMajor, - cutlass::half_t, - layout::RowMajor, - Operator_> { - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16, 8, 32>; - - using ElementA = cutlass::float_e5m2_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e4m3_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = cutlass::half_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - uint32_t const *C = reinterpret_cast(&c); - uint32_t *D = reinterpret_cast(&d); - - asm( - "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16 " - "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" - : "=r"(D[0]), "=r"(D[1]) - : - "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]) - ); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -/// Matrix multiply-add operation - F16 = fe5m2 * fe5m2 + F16 -template -struct Mma< - gemm::GemmShape<16, 8, 32>, - 32, - cutlass::float_e5m2_t, - layout::RowMajor, - cutlass::float_e5m2_t, - layout::ColumnMajor, - cutlass::half_t, - layout::RowMajor, - Operator_> { - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16, 8, 32>; - - using ElementA = cutlass::float_e5m2_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e5m2_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = cutlass::half_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - uint32_t const *C = reinterpret_cast(&c); - uint32_t *D = reinterpret_cast(&d); - - asm( - "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16 " - "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" - : "=r"(D[0]), "=r"(D[1]) - : - "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]) - ); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); - -#endif - } -}; - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm90.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm90.h deleted file mode 100644 index b135c8645b48eb40a1cce88c515074e95d4b6a5e..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm90.h +++ /dev/null @@ -1,241 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Matrix multiply -*/ - -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) - -#include "mma.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/config.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -//////////////////////////////////////////////////////////////////////////////// -/// Matrix Multiply-Add 16x8x4 fp64 -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F64 = F64 * F64 + F64 -template <> -struct Mma< - gemm::GemmShape<16,8,4>, - 32, - double, - layout::RowMajor, - double, - layout::ColumnMajor, - double, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,4>; - - using ElementA = double; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = double; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = double; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - - using ArchTag = arch::Sm90; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - - double const *A = reinterpret_cast(&a); - double const *B = reinterpret_cast(&b); - - double const *C = reinterpret_cast(&c); - double *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64.rn {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) - : "d"(A[0]), "d"(A[1]), - "d"(B[0]), - "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); - -#else - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Matrix Multiply-Add 16x8x8 fp64 -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F64 = F64 * F64 + F64 -template <> -struct Mma< - gemm::GemmShape<16,8,8>, - 32, - double, - layout::RowMajor, - double, - layout::ColumnMajor, - double, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,8>; - - using ElementA = double; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = double; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = double; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - - using ArchTag = arch::Sm90; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - - double const *A = reinterpret_cast(&a); - double const *B = reinterpret_cast(&b); - - double const *C = reinterpret_cast(&c); - double *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=d"(D[0]), "=d"(d[1]), "=d"(d[2]), "=d"(d[3]) - : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]), - "d"(B[0]), "d"(B[1]), - "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); - -#else - - CUTLASS_UNUSED(d); - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Matrix Multiply-Add 16x8x16 fp64 -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F64 = F64 * F64 + F64 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - double, - layout::RowMajor, - double, - layout::ColumnMajor, - double, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = double; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = double; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = double; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - - using ArchTag = arch::Sm90; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c) const { - -#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - - double const *A = reinterpret_cast(&a); - double const *B = reinterpret_cast(&b); - - double const *C = reinterpret_cast(&c); - double *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7, %8, %9, %10, %11}, {%12, %13, %14, %15}, {%16, %17, %18, %19};\n" - : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) - : "d"(A[0]), "d"(A[2]), "d"(A[2]), "d"(A[3]), "d"(A[4]), "d"(A[5]), "d"(A[6]), "d"(A[7]), - "d"(B[0]), "d"(B[1]), "d"(B[2]), "d"(B[3]), - "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); - -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm80.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm80.h deleted file mode 100644 index e4ca91a10293334fbd89e21891132442a6216e6a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm80.h +++ /dev/null @@ -1,1234 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Sparse matrix multiply accumulate for SM80 -*/ - -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) - -#include "mma.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1)) - -#define CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED 1 - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -#define CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED -#endif - -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// -// -// Sparse Matrix Multiply 16832 -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -template <> -struct SparseMma< - gemm::GemmShape<16, 8, 32>, - 32, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - OpMultiplyAdd, - SPFormatType::Thread -> { - - using Shape = gemm::GemmShape<16, 8, 32>; - - using ElementA = half_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = half_t; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 2; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c, uint32_t const &E, int const id2) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - uint32_t const *C = reinterpret_cast(&c); - uint32_t *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " - "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); - } - else if (id2 == 1) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " - "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); - } - else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " - "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); - } - else if (id2 == 1) { - asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " - "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); - } - else { - assert(0); - } -#endif - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -template <> -struct SparseMma< - gemm::GemmShape<16, 8, 32>, - 32, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - float, - layout::RowMajor, - OpMultiplyAdd, - SPFormatType::Thread - > { - - using Shape = gemm::GemmShape<16, 8, 32>; - - using ElementA = half_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = half_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 2; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c, uint32_t const &E, int const id2) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), - "r"(E)); - } - else if (id2 == 1) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), - "r"(E)); - } - else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), - "r"(E)); - } - else if (id2 == 1) { - asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), - "r"(E)); - } - else { - assert(0); - } - -#endif - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Sparse Matrix Multiply 16832 - Float BF16, FP32 accumulation -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 -template <> -struct SparseMma, 32, bfloat16_t, layout::RowMajor, - bfloat16_t, layout::ColumnMajor, float, layout::RowMajor, - OpMultiplyAdd, SPFormatType::Thread> { - using Shape = gemm::GemmShape<16, 8, 32>; - - using ElementA = bfloat16_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = bfloat16_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 2; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c, uint32_t const &E, int const id2) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } else if (id2 == 1) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } else if (id2 == 1) { - asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } else { - assert(0); - } -#endif - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Sparse Matrix Multiply 16816 - Float TF32 -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 -template <> -struct SparseMma, 32, tfloat32_t, layout::RowMajor, - tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, - OpMultiplyAdd, SPFormatType::Thread> { - using Shape = gemm::GemmShape<16, 8, 16>; - - using ElementA = tfloat32_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = tfloat32_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 4; - - static int const kMaxID2 = 2; - - CUTLASS_HOST_DEVICE - void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, - FragmentC const &c, uint32_t const &E, int const id2) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } else if (id2 == 1) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } else if (id2 == 1) { - asm volatile( - "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } else { - assert(0); - } -#endif - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation - SATURATE -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,64>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,64>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#endif - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,64>, - 32, - int8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,64>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#endif - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,64>, - 32, - uint8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,64>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#endif - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,64>, - 32, - uint8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,64>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#endif - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#endif - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#endif - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#endif - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - -#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) - if (id2 == 0) { - asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#else - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - } else { - assert(0); - } -#endif - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm89.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm89.h deleted file mode 100644 index 6adca25527efdc1c3cb564b4553d96bebe59b3fd..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm89.h +++ /dev/null @@ -1,406 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Sparse matrix multiply accumulate for SM89 -*/ - -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) - -#include "mma.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) -# define CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED -#endif - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) -# if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED) -# define CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED -# endif -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = fe4m3 * fe4m3 + F32 -template -struct SparseMma< - gemm::GemmShape<16,8,64>, - 32, - cutlass::float_e4m3_t, - layout::RowMajor, - cutlass::float_e4m3_t, - layout::ColumnMajor, - float, - layout::RowMajor, - Operator_, - SPFormatType::Thread> { - - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16,8,64>; - - using ElementA = cutlass::float_e4m3_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e4m3_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } - else { - assert(0); - } -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = fe4m3 * fe5m2 + F32 -template -struct SparseMma< - gemm::GemmShape<16,8,64>, - 32, - cutlass::float_e4m3_t, - layout::RowMajor, - cutlass::float_e5m2_t, - layout::ColumnMajor, - float, - layout::RowMajor, - Operator_, - SPFormatType::Thread> { - - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16,8,64>; - - using ElementA = cutlass::float_e4m3_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e5m2_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } - else { - assert(0); - } -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = fe5m2 * fe4m3 + F32 -template -struct SparseMma< - gemm::GemmShape<16,8,64>, - 32, - cutlass::float_e5m2_t, - layout::RowMajor, - cutlass::float_e4m3_t, - layout::ColumnMajor, - float, - layout::RowMajor, - Operator_, - SPFormatType::Thread> { - - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16,8,64>; - - using ElementA = cutlass::float_e5m2_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e4m3_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } - else { - assert(0); - } -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: F32 = fe5m2 * fe5m2 + F32 -template -struct SparseMma< - gemm::GemmShape<16,8,64>, - 32, - cutlass::float_e5m2_t, - layout::RowMajor, - cutlass::float_e5m2_t, - layout::ColumnMajor, - float, - layout::RowMajor, - Operator_, - SPFormatType::Thread> { - - static_assert(platform::is_same::value || - platform::is_same::value, - "Invalid operator for SM89 FP8 instruction"); - - using Shape = gemm::GemmShape<16,8,64>; - - using ElementA = cutlass::float_e5m2_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::float_e5m2_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = float; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = Operator_; - using ArchTag = arch::Sm89; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - float const *C = reinterpret_cast(&c); - float *D = reinterpret_cast(&d); - - if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); - } - else { - assert(0); - } -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/reg_reconfig.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/reg_reconfig.h deleted file mode 100644 index 93dd37d3193867602d69866cb2cfcd2e27e87f62..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/reg_reconfig.h +++ /dev/null @@ -1,89 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief PTX for CTA Reconfiguration -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#if defined(__CUDACC_RTC__) -#include -#else -#include -#endif - -#ifndef CUDA_CTA_RECONFIG_ACTIVATED - #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \ - (__CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) \ - || (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) \ - || (__CUDA_ARCH__ == 1010 && defined(__CUDA_ARCH_FEAT_SM101_ALL)) \ - || (__CUDA_ARCH__ == 1030 && defined(__CUDA_ARCH_FEAT_SM103_ALL)) \ - || (__CUDA_ARCH__ == 1200 && defined(__CUDA_ARCH_FEAT_SM120_ALL)) \ - || (__CUDA_ARCH__ == 1210 && defined(__CUDA_ARCH_FEAT_SM121_ALL)) \ - ) - #define CUDA_CTA_RECONFIG_ACTIVATED 1 - #endif - - #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \ - (__CUDA_ARCH__ == 1000 && CUDA_ARCH_FAMILY(1000)) \ - || (__CUDA_ARCH__ == 1010 && CUDA_ARCH_FAMILY(1010)) \ - || (__CUDA_ARCH__ == 1030 && CUDA_ARCH_FAMILY(1030)) \ - || (__CUDA_ARCH__ == 1200 && CUDA_ARCH_FAMILY(1200)) \ - || (__CUDA_ARCH__ == 1210 && CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) \ - ) - #define CUDA_CTA_RECONFIG_ACTIVATED 1 - #endif - -#endif - -namespace cutlass { -namespace arch { - -template -CUTLASS_DEVICE -void warpgroup_reg_alloc(){ -#if CUDA_CTA_RECONFIG_ACTIVATED - asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); -#endif -} - -template -CUTLASS_DEVICE -void warpgroup_reg_dealloc(){ -#if CUDA_CTA_RECONFIG_ACTIVATED - asm volatile( "setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); -#endif -} - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd.h deleted file mode 100644 index a1dc7dff4d603ecf7e6a190c84bc7634e8c8be62..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd.h +++ /dev/null @@ -1,125 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing SIMD operators -*/ - -#pragma once - -#include "cutlass/arch/array.h" -#include "cutlass/arch/numeric_types.h" - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Element-wise operators -// - -CUTLASS_HOST_DEVICE -template -Array operator*(Array const &a, Array const &b) { - Array d; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - d[i] = a[i] * b[i]; - } - return d; -} - -CUTLASS_HOST_DEVICE -template -Array operator+(Array const &a, Array const &b) { - Array d; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - d[i] = a[i] + b[i]; - } - return d; -} - -CUTLASS_HOST_DEVICE -template -Array operator-(Array const &a, Array const &b) { - Array d; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - d[i] = a[i] - b[i]; - } - return d; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Multiply-accumulate operators -// - -CUTLASS_HOST_DEVICE -template -Array mac(Array const &a, Array const &b, Array const &c) { - Array d; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - d[i] = a[i] * b[i] + c[i]; - } - return d; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Dot product operator -// - -CUTLASS_HOST_DEVICE -template -Accumulator dot(Array const &a, Array const &b, Accumulator accum) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - accum += a[i] * b[i]; - } - return accum; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "simd_sm60.h" -#include "simd_sm61.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm60.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm60.h deleted file mode 100644 index 59f38d62da91ab9af6a1f73a5990d29056dd259a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm60.h +++ /dev/null @@ -1,104 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing SIMD operators for SM60 -*/ - -#pragma once - -#include "simd.h" - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Element-wise operators - specialized for half_t x 2 -// - -CUTLASS_HOST_DEVICE -template <> -Array operator*(Array const &a, Array const &b) { - Array d; - - return d; -} - -CUTLASS_HOST_DEVICE -template <> -Array operator+(AArray const &a, Array const &b) { - Array d; - - return d; -} - -CUTLASS_HOST_DEVICE -template <> -Array operator-(Array const &a, Array const &b) { - Array d; - - return d; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Multiply-accumulate operators - specialized for half_t x 2 -CUTLASS_HOST_DEVICE -template <> -Array mac(Array const &a, Array const &b, Array const &c) { - Array d; - - return d; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Dot product operator - specialized for half_t <- (half_t * half_t) x 2 + half_t -CUTLASS_HOST_DEVICE -template <> -half_t dot(Array const &a, Array const &b, half_t accum) { - - return accum; -} - -/// Dot product operator - specialized for float <- (half_t * half_t) x 2 + float -CUTLASS_HOST_DEVICE -template <> -float dot(Array const &a, Array const &b, float accum) { - - return accum; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm61.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm61.h deleted file mode 100644 index 46c22665c2126b5dd2e0fb143be00143b933f3ec..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm61.h +++ /dev/null @@ -1,147 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing SIMD operators for SM61 -*/ - -#pragma once - -#include "simd.h" - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Dot product operator - specialized for int32_t <- (int8_t * int8_t) x 4 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -/// Dot product operator - specialized for int32_t <- (uint8_t * int8_t) x 4 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -/// Dot product operator - specialized for int32_t <- (int8_t * uint8_t) x 4 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -/// Dot product operator - specialized for int32_t <- (uint8_t * uint8_t) x 4 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -/// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t -CUTLASS_HOST_DEVICE -template <> -int32_t dot(Array const &a, Array const &b, int32_t accum) { - - return accum; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/synclog.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/synclog.hpp deleted file mode 100644 index 5567fe561f8ca7a95f7b0958aaced2696109f22a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/synclog.hpp +++ /dev/null @@ -1,1271 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Synchronization event logging for race condition debugging. -*/ - -#pragma once - -#include "cutlass/detail/helper_macros.hpp" -#include "cutlass/cutlass.h" -#if defined(__CUDACC_RTC__) -#include CUDA_STD_HEADER(cstdint) -#else -#include -#endif - -#if !defined(__CUDACC_RTC__) -#include -#include -#endif - -namespace cutlass { -namespace arch { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTLASS_ENABLE_SYNCLOG) - -constexpr uint32_t synclog_cap = 1 << 26; - -inline std::mutex synclog_mutex; -inline std::vector synclog_buf_list; -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -CUTLASS_DEVICE uint32_t* synclog_buf; -#endif - -CUTLASS_DEVICE -uint32_t* synclog_alloc(uint32_t n) { - #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) - uint32_t* buf = synclog_buf; - if (buf == nullptr) return nullptr; - uint32_t last = atomicAdd(&buf[0], n); - if (last + n < synclog_cap) return buf + last + 1; - if (last >= synclog_cap) atomicAdd(&buf[0], -n); - #endif - return nullptr; -} - -CUTLASS_DEVICE -void synclog_emit_prefix(uint32_t* to, uint32_t header, uint32_t line) { - #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) - uint64_t time64; - asm volatile ( - "mov.u64 %0, %%globaltimer;\n" - : "=l"(time64) : - ); - to[0] = header; - to[1] = line; - to[2] = time64; - to[3] = time64 >> 32; - to[4] = threadIdx.x; - to[5] = threadIdx.y; - to[6] = threadIdx.z; - to[7] = blockIdx.x; - to[8] = blockIdx.y; - to[9] = blockIdx.z; - #endif -} - -constexpr uint32_t synclog_header_none = 0; -constexpr uint32_t synclog_length_prefix = 1 + 1 + 2 + 3 + 3; - -constexpr bool synclog_enable_syncthreads = true; -constexpr uint32_t synclog_header_syncthreads = 1; -constexpr uint32_t synclog_length_syncthreads = synclog_length_prefix + 0; - -constexpr bool synclog_enable_syncwarp = true; -constexpr uint32_t synclog_header_syncwarp = 2; -constexpr uint32_t synclog_length_syncwarp = synclog_length_prefix + 0; - -constexpr bool synclog_enable_named_barrier_arrive_and_wait = true; -constexpr uint32_t synclog_header_named_barrier_arrive_and_wait = 3; -constexpr uint32_t synclog_length_named_barrier_arrive_and_wait = synclog_length_prefix + 2; - -constexpr bool synclog_enable_named_barrier_arrive = true; -constexpr uint32_t synclog_header_named_barrier_arrive = 4; -constexpr uint32_t synclog_length_named_barrier_arrive = synclog_length_prefix + 2; - -constexpr bool synclog_enable_cluster_barrier_init = true; -constexpr uint32_t synclog_header_cluster_barrier_init = 5; -constexpr uint32_t synclog_length_cluster_barrier_init = synclog_length_prefix + 2; - -constexpr bool synclog_enable_cluster_barrier_wait = true; -constexpr uint32_t synclog_header_cluster_barrier_wait = 6; -constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 2; -constexpr bool synclog_enable_cluster_barrier_test_wait = true; -constexpr uint32_t synclog_header_cluster_barrier_test_wait = 7; -constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 3; -constexpr bool synclog_enable_cluster_barrier_try_wait = true; -constexpr uint32_t synclog_header_cluster_barrier_try_wait = 8; -constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 2; -constexpr bool synclog_enable_cluster_barrier_arrive_cluster = true; -constexpr uint32_t synclog_header_cluster_barrier_arrive_cluster = 9; -constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 3; -constexpr bool synclog_enable_cluster_barrier_arrive = true; -constexpr uint32_t synclog_header_cluster_barrier_arrive = 10; -constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 1; -constexpr bool synclog_enable_cluster_barrier_invalidate = true; -constexpr uint32_t synclog_header_cluster_barrier_invalidate = 11; -constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 1; -constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx = true; -constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx = 12; -constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = synclog_length_prefix + 2; -constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster = true; -constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster = 13; -constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = synclog_length_prefix + 4; -constexpr bool synclog_enable_cluster_transaction_barrier_expect_transaction = true; -constexpr uint32_t synclog_header_cluster_transaction_barrier_expect_transaction = 14; -constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = synclog_length_prefix + 2; -constexpr bool synclog_enable_cluster_transaction_barrier_complete_transaction = true; -constexpr uint32_t synclog_header_cluster_transaction_barrier_complete_transaction = 15; -constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = synclog_length_prefix + 4; -constexpr bool synclog_enable_fence_barrier_init = true; -constexpr uint32_t synclog_header_fence_barrier_init = 16; -constexpr uint32_t synclog_length_fence_barrier_init = synclog_length_prefix + 0; - -constexpr bool synclog_enable_fence_view_async_shared = true; -constexpr uint32_t synclog_header_fence_view_async_shared = 17; -constexpr uint32_t synclog_length_fence_view_async_shared = synclog_length_prefix + 0; - -constexpr bool synclog_enable_cp_async_wait = true; -constexpr uint32_t synclog_header_cp_async_wait = 18; -constexpr uint32_t synclog_length_cp_async_wait = synclog_length_prefix + 1; - -constexpr bool synclog_enable_cp_async_wait_all = true; -constexpr uint32_t synclog_header_cp_async_wait_all = 19; -constexpr uint32_t synclog_length_cp_async_wait_all = synclog_length_prefix + 0; - -constexpr bool synclog_enable_cp_async_fence = true; -constexpr uint32_t synclog_header_cp_async_fence = 20; -constexpr uint32_t synclog_length_cp_async_fence = synclog_length_prefix + 0; - -constexpr bool synclog_enable_cp_async_nan = true; -constexpr uint32_t synclog_header_cp_async_nan = 21; -constexpr uint32_t synclog_length_cp_async_nan = synclog_length_prefix + 4; - -constexpr bool synclog_enable_cp_async_zfill = true; -constexpr uint32_t synclog_header_cp_async_zfill = 22; -constexpr uint32_t synclog_length_cp_async_zfill = synclog_length_prefix + 5; - -constexpr bool synclog_enable_cp_async = true; -constexpr uint32_t synclog_header_cp_async = 23; -constexpr uint32_t synclog_length_cp_async = synclog_length_prefix + 5; - -constexpr bool synclog_enable_tma_load = true; -constexpr uint32_t synclog_header_tma_load = 24; -constexpr uint32_t synclog_length_tma_load = synclog_length_prefix + 4; - -constexpr bool synclog_enable_tma_store = true; -constexpr uint32_t synclog_header_tma_store = 25; -constexpr uint32_t synclog_length_tma_store = synclog_length_prefix + 3; - -constexpr bool synclog_enable_tma_store_arrive = true; -constexpr uint32_t synclog_header_tma_store_arrive = 26; -constexpr uint32_t synclog_length_tma_store_arrive = synclog_length_prefix + 0; - -constexpr bool synclog_enable_tma_store_wait = true; -constexpr uint32_t synclog_header_tma_store_wait = 27; -constexpr uint32_t synclog_length_tma_store_wait = synclog_length_prefix + 1; - -constexpr bool synclog_enable_warpgroup_arrive = true; -constexpr uint32_t synclog_header_warpgroup_arrive = 28; -constexpr uint32_t synclog_length_warpgroup_arrive = synclog_length_prefix + 0; - -constexpr bool synclog_enable_warpgroup_wait = true; -constexpr uint32_t synclog_header_warpgroup_wait = 29; -constexpr uint32_t synclog_length_warpgroup_wait = synclog_length_prefix + 1; - -constexpr bool synclog_enable_warpgroup_commit_batch = true; -constexpr uint32_t synclog_header_warpgroup_commit_batch = 30; -constexpr uint32_t synclog_length_warpgroup_commit_batch = synclog_length_prefix + 0; - -constexpr bool synclog_enable_wgmma_reg_smem = true; -constexpr uint32_t synclog_header_wgmma_reg_smem = 31; -constexpr uint32_t synclog_length_wgmma_reg_smem = synclog_length_prefix + 2; - -constexpr bool synclog_enable_wgmma_smem_smem = true; -constexpr uint32_t synclog_header_wgmma_smem_smem = 32; -constexpr uint32_t synclog_length_wgmma_smem_smem = synclog_length_prefix + 4; - -constexpr bool synclog_enable_cpasync_barrier_arrive = true; -constexpr uint32_t synclog_header_cpasync_barrier_arrive = 33; -constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 1; -CUTLASS_DEVICE -bool synclog_condition_emit() { - #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) - return threadIdx.x % NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 && - blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; - #else - return 0; - #endif -} - -CUTLASS_DEVICE -bool synclog_condition_print() { - #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) - return threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && - blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; - #else - return false; - #endif -} - -CUTLASS_DEVICE -void synclog_print_prefix(char const* header, uint32_t at) { - #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) - uint32_t line = synclog_buf[at + 1]; - uint32_t timeLo = synclog_buf[at + 2]; - uint32_t timeHi = synclog_buf[at + 3]; - uint32_t threadIdxX = synclog_buf[at + 4]; - uint32_t threadIdxY = synclog_buf[at + 5]; - uint32_t threadIdxZ = synclog_buf[at + 6]; - uint32_t blockIdxX = synclog_buf[at + 7]; - uint32_t blockIdxY = synclog_buf[at + 8]; - uint32_t blockIdxZ = synclog_buf[at + 9]; - printf( - "%s line=%u time=%lu thread=%u,%u,%u block=%u,%u,%u ", - header, line, - (uint64_t)timeHi << 32 | timeLo, - threadIdxX, threadIdxY, threadIdxZ, - blockIdxX, blockIdxY, blockIdxZ - ); - #endif -} - -CUTLASS_DEVICE -void synclog_print_wgmma_desc(char const* str, uint32_t lo, uint32_t hi, char const* sep) { - CUTLASS_UNUSED(hi); - uint32_t smem_int_ptr = (lo & ((1 << 14) - 1)) << 4; - printf("%s_smem_int_ptr=%u%s", str, smem_int_ptr, sep); -} - -#endif // defined(CUTLASS_ENABLE_SYNCLOG) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline void synclog_setup() { - #if defined(CUTLASS_ENABLE_SYNCLOG) - #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) - std::scoped_lock lock(synclog_mutex); - auto fail = [] () { - fprintf(stderr, "synclog_setup() failed\n"); - std::terminate(); - }; - int orig_device = 0; - if (cudaGetDevice(&orig_device) != cudaSuccess) { - fail(); - } - int device_count = 0; - if (cudaGetDeviceCount(&device_count) != cudaSuccess) { - fail(); - } - if (synclog_buf_list.size() == 0) { - for (int device = 0; device < device_count; device++) { - uint32_t* buf = 0; - if (cudaSetDevice(device) != cudaSuccess || - cudaMalloc(&buf, synclog_cap * sizeof(uint32_t)) != cudaSuccess) { - fail(); - } - synclog_buf_list.push_back(buf); - } - } - for (int device = 0; device < device_count; device++) { - uint32_t* buf = synclog_buf_list.at(device); - if (cudaSetDevice(device) != cudaSuccess || - cudaMemset(buf, 0, synclog_cap * sizeof(uint32_t)) != cudaSuccess || - cudaMemcpyToSymbol(synclog_buf, &buf, sizeof(buf)) != cudaSuccess) { - fail(); - } - } - if (cudaSetDevice(orig_device) != cudaSuccess) { - fail(); - } - #endif - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_syncthreads(uint32_t line) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_syncthreads) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_syncthreads); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_syncthreads, line); - #else - CUTLASS_UNUSED(line); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_syncwarp(uint32_t line) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_syncwarp) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_syncwarp); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_syncwarp, line); - #else - CUTLASS_UNUSED(line); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_named_barrier_arrive_and_wait( - uint32_t line, - uint32_t num_threads, - uint32_t barrier_id) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_named_barrier_arrive_and_wait) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive_and_wait); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_named_barrier_arrive_and_wait, line); - to[synclog_length_prefix + 0] = num_threads; - to[synclog_length_prefix + 1] = barrier_id; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(num_threads); - CUTLASS_UNUSED(barrier_id); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_named_barrier_arrive( - uint32_t line, - uint32_t num_threads, - uint32_t barrier_id) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_named_barrier_arrive) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_named_barrier_arrive, line); - to[synclog_length_prefix + 0] = num_threads; - to[synclog_length_prefix + 1] = barrier_id; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(num_threads); - CUTLASS_UNUSED(barrier_id); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cluster_barrier_init( - uint32_t line, - uint32_t smem_addr, - uint32_t arrive_count) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cluster_barrier_init) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_init); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cluster_barrier_init, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = arrive_count; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(arrive_count); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cluster_barrier_wait( - uint32_t line, - uint32_t smem_addr, - uint32_t phase) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cluster_barrier_wait) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_wait); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cluster_barrier_wait, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = phase; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(phase); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cluster_barrier_test_wait( - uint32_t line, - uint32_t smem_addr, - uint32_t phase, - uint32_t pred) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cluster_barrier_test_wait) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_test_wait); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cluster_barrier_test_wait, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = phase; - to[synclog_length_prefix + 2] = pred; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(phase); - CUTLASS_UNUSED(pred); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cluster_barrier_try_wait( - uint32_t line, - uint32_t smem_addr, - uint32_t phase) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cluster_barrier_try_wait) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_try_wait); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cluster_barrier_try_wait, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = phase; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(phase); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cluster_barrier_arrive_cluster( - uint32_t line, - uint32_t smem_addr, - uint32_t cta_id, - uint32_t pred) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cluster_barrier_arrive_cluster) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive_cluster); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive_cluster, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = cta_id; - to[synclog_length_prefix + 2] = pred; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(cta_id); - CUTLASS_UNUSED(pred); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cluster_barrier_arrive( - uint32_t line, - uint32_t smem_addr) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cluster_barrier_arrive) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive, line); - to[synclog_length_prefix + 0] = smem_addr; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cluster_barrier_invalidate( - uint32_t line, - uint32_t smem_addr) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cluster_barrier_invalidate) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_invalidate); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cluster_barrier_invalidate, line); - to[synclog_length_prefix + 0] = smem_addr; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx( - uint32_t line, - uint32_t smem_addr, - uint32_t transaction_bytes) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = transaction_bytes; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(transaction_bytes); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster( - uint32_t line, - uint32_t smem_addr, - uint32_t transaction_bytes, - uint32_t cta_id, - uint32_t pred) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = transaction_bytes; - to[synclog_length_prefix + 2] = cta_id; - to[synclog_length_prefix + 3] = pred; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(transaction_bytes); - CUTLASS_UNUSED(cta_id); - CUTLASS_UNUSED(pred); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cluster_transaction_barrier_expect_transaction( - uint32_t line, - uint32_t smem_addr, - uint32_t transaction_bytes) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cluster_transaction_barrier_expect_transaction) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_expect_transaction); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_expect_transaction, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = transaction_bytes; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(transaction_bytes); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cluster_transaction_barrier_complete_transaction( - uint32_t line, - uint32_t smem_addr, - uint32_t dst_cta_id, - uint32_t transaction_bytes, - uint32_t pred) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cluster_transaction_barrier_complete_transaction) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_complete_transaction); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_complete_transaction, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = dst_cta_id; - to[synclog_length_prefix + 2] = transaction_bytes; - to[synclog_length_prefix + 3] = pred; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(dst_cta_id); - CUTLASS_UNUSED(transaction_bytes); - CUTLASS_UNUSED(pred); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_fence_barrier_init(uint32_t line) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_fence_barrier_init) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_fence_barrier_init); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_fence_barrier_init, line); - #else - CUTLASS_UNUSED(line); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_fence_view_async_shared(uint32_t line) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_fence_view_async_shared) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_fence_view_async_shared); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_fence_view_async_shared, line); - #else - CUTLASS_UNUSED(line); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cp_async_wait( - uint32_t line, - uint32_t n) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cp_async_wait) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cp_async_wait); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cp_async_wait, line); - to[synclog_length_prefix + 0] = n; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(n); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cp_async_wait_all(uint32_t line) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cp_async_wait_all) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cp_async_wait_all); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cp_async_wait_all, line); - #else - CUTLASS_UNUSED(line); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cp_async_fence(uint32_t line) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cp_async_fence) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cp_async_fence); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cp_async_fence, line); - #else - CUTLASS_UNUSED(line); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cp_async_nan( - uint32_t line, - uint32_t smem_addr, - const void* gmem_ptr, - uint32_t pred) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cp_async_nan) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cp_async_nan); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cp_async_nan, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); - to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); - to[synclog_length_prefix + 3] = pred; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(gmem_ptr); - CUTLASS_UNUSED(pred); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cp_async_zfill( - uint32_t line, - uint32_t smem_addr, - const void* gmem_ptr, - uint32_t pred, - uint32_t size) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cp_async_zfill) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cp_async_zfill); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cp_async_zfill, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); - to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); - to[synclog_length_prefix + 3] = pred; - to[synclog_length_prefix + 4] = size; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(gmem_ptr); - CUTLASS_UNUSED(pred); - CUTLASS_UNUSED(size); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cp_async( - uint32_t line, - uint32_t smem_addr, - const void* gmem_ptr, - uint32_t pred, - uint32_t size) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cp_async) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cp_async); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cp_async, line); - to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); - to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); - to[synclog_length_prefix + 3] = pred; - to[synclog_length_prefix + 4] = size; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - CUTLASS_UNUSED(gmem_ptr); - CUTLASS_UNUSED(pred); - CUTLASS_UNUSED(size); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_tma_load( - uint32_t line, - uint64_t gmem_int_desc, - uint32_t smem_int_mbar, - uint32_t smem_int_ptr) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_tma_load) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_tma_load); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_tma_load, line); - to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc); - to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32); - to[synclog_length_prefix + 2] = smem_int_mbar; - to[synclog_length_prefix + 3] = smem_int_ptr; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(gmem_int_desc); - CUTLASS_UNUSED(smem_int_mbar); - CUTLASS_UNUSED(smem_int_ptr); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_tma_store( - uint32_t line, - uint64_t gmem_int_desc, - uint32_t smem_int_ptr) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_tma_store) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_tma_store); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_tma_store, line); - to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc); - to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32); - to[synclog_length_prefix + 2] = smem_int_ptr; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(gmem_int_desc); - CUTLASS_UNUSED(smem_int_ptr); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_tma_store_arrive(uint32_t line) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_tma_store_arrive) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_tma_store_arrive); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_tma_store_arrive, line); - #else - CUTLASS_UNUSED(line); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_tma_store_wait( - uint32_t line, - uint32_t count) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_tma_store_wait) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_tma_store_wait); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_tma_store_wait, line); - to[synclog_length_prefix + 0] = count; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(count); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_warpgroup_arrive( - uint32_t line) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_warpgroup_arrive) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_warpgroup_arrive); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_warpgroup_arrive, line); - #else - CUTLASS_UNUSED(line); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_warpgroup_wait( - uint32_t line, - uint32_t n) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_warpgroup_wait) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_warpgroup_wait); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_warpgroup_wait, line); - to[synclog_length_prefix + 0] = n; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(n); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_warpgroup_commit_batch( - uint32_t line) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_warpgroup_commit_batch) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_warpgroup_commit_batch); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_warpgroup_commit_batch, line); - #else - CUTLASS_UNUSED(line); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_wgmma_reg_smem( - uint32_t line, - uint64_t desc_b) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_wgmma_reg_smem) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_wgmma_reg_smem); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_wgmma_reg_smem, line); - to[synclog_length_prefix + 0] = desc_b; - to[synclog_length_prefix + 1] = desc_b >> 32; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(desc_b); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_wgmma_smem_smem( - uint32_t line, - uint64_t desc_a, - uint64_t desc_b) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_wgmma_smem_smem) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_wgmma_smem_smem); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_wgmma_smem_smem, line); - to[synclog_length_prefix + 0] = desc_a; - to[synclog_length_prefix + 1] = desc_a >> 32; - to[synclog_length_prefix + 2] = desc_b; - to[synclog_length_prefix + 3] = desc_b >> 32; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(desc_a); - CUTLASS_UNUSED(desc_b); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -CUTLASS_DEVICE -void synclog_emit_cpasync_barrier_arrive( - uint32_t line, - uint32_t smem_addr) { - #if defined(CUTLASS_ENABLE_SYNCLOG) - if constexpr (!synclog_enable_cpasync_barrier_arrive) return; - if (!synclog_condition_emit()) return; - uint32_t* to = synclog_alloc(synclog_length_cpasync_barrier_arrive); - if (to == nullptr) return; - synclog_emit_prefix(to, synclog_header_cpasync_barrier_arrive, line); - to[synclog_length_prefix + 0] = smem_addr; - #else - CUTLASS_UNUSED(line); - CUTLASS_UNUSED(smem_addr); - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -#if !defined(CUTLASS_ENABLE_SYNCLOG) -CUTLASS_DEVICE -#elif defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -static __attribute__((__noinline__)) __device__ -#else -static __attribute__((__noinline__)) -#endif -void synclog_print() { - #if defined(CUTLASS_ENABLE_SYNCLOG) - #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) - if (synclog_buf == nullptr || !synclog_condition_print()) { - return; - } - printf("synclog start\n"); - for (uint32_t at = 1; at < synclog_cap; ) { - uint32_t header = synclog_buf[at]; - if (header == synclog_header_none) { - break; - } - printf("synclog at %u: ", at); - if constexpr (synclog_enable_syncthreads) { - if (header == synclog_header_syncthreads) { - synclog_print_prefix("syncthreads", at); - at += synclog_length_syncthreads; - printf("\n"); - continue; - } - } - if constexpr (synclog_enable_syncwarp) { - if (header == synclog_header_syncwarp) { - synclog_print_prefix("syncwarp", at); - at += synclog_length_syncwarp; - printf("\n"); - continue; - } - } - if constexpr (synclog_enable_named_barrier_arrive_and_wait) { - if (header == synclog_header_named_barrier_arrive_and_wait) { - synclog_print_prefix("named_barrier_arrive_and_wait", at); - at += synclog_length_named_barrier_arrive_and_wait; - printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_named_barrier_arrive) { - if (header == synclog_header_named_barrier_arrive) { - synclog_print_prefix("named_barrier_arrive", at); - at += synclog_length_named_barrier_arrive; - printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cluster_barrier_init) { - if (header == synclog_header_cluster_barrier_init) { - synclog_print_prefix("cluster_barrier_init", at); - at += synclog_length_cluster_barrier_init; - printf("smem_addr=%u arrive_count=%u\n", synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cluster_barrier_wait) { - if (header == synclog_header_cluster_barrier_wait) { - synclog_print_prefix("cluster_barrier_wait", at); - at += synclog_length_cluster_barrier_wait; - printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cluster_barrier_test_wait) { - if (header == synclog_header_cluster_barrier_test_wait) { - synclog_print_prefix("cluster_barrier_test_wait", at); - at += synclog_length_cluster_barrier_test_wait; - printf("smem_addr=%u phase=%u pred=%u\n", synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cluster_barrier_try_wait) { - if (header == synclog_header_cluster_barrier_try_wait) { - synclog_print_prefix("cluster_barrier_try_wait", at); - at += synclog_length_cluster_barrier_try_wait; - printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cluster_barrier_arrive_cluster) { - if (header == synclog_header_cluster_barrier_arrive_cluster) { - synclog_print_prefix("cluster_barrier_arrive_cluster", at); - at += synclog_length_cluster_barrier_arrive_cluster; - printf("smem_addr=%u cta_id=%u pred=%u\n", synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cluster_barrier_arrive) { - if (header == synclog_header_cluster_barrier_arrive) { - synclog_print_prefix("cluster_barrier_arrive", at); - at += synclog_length_cluster_barrier_arrive; - printf("smem_addr=%u\n", synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cluster_barrier_invalidate) { - if (header == synclog_header_cluster_barrier_invalidate) { - synclog_print_prefix("cluster_barrier_invalidate", at); - at += synclog_length_cluster_barrier_invalidate; - printf("smem_addr=%u\n", synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) { - if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx) { - synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx", at); - at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx; - printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { - if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { - synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx_cluster", at); - at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster; - printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cluster_transaction_barrier_expect_transaction) { - if (header == synclog_header_cluster_transaction_barrier_expect_transaction) { - synclog_print_prefix("cluster_transaction_barrier_expect_transaction", at); - at += synclog_length_cluster_transaction_barrier_expect_transaction; - printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cluster_transaction_barrier_complete_transaction) { - if (header == synclog_header_cluster_transaction_barrier_complete_transaction) { - synclog_print_prefix("cluster_transaction_barrier_complete_transaction", at); - at += synclog_length_cluster_transaction_barrier_complete_transaction; - printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_fence_barrier_init) { - if (header == synclog_header_fence_barrier_init) { - synclog_print_prefix("fence_barrier_init", at); - at += synclog_length_fence_barrier_init; - printf("\n"); - continue; - } - } - if constexpr (synclog_enable_fence_view_async_shared) { - if (header == synclog_header_fence_view_async_shared) { - synclog_print_prefix("fence_view_async_shared", at); - at += synclog_length_fence_view_async_shared; - printf("\n"); - continue; - } - } - if constexpr (synclog_enable_cp_async_wait) { - if (header == synclog_header_cp_async_wait) { - synclog_print_prefix("cp_async_wait", at); - at += synclog_length_cp_async_wait; - printf("n=%u\n", synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cp_async_wait_all) { - if (header == synclog_header_cp_async_wait_all) { - synclog_print_prefix("cp_async_wait_all", at); - at += synclog_length_cp_async_wait_all; - printf("\n"); - continue; - } - } - if constexpr (synclog_enable_cp_async_fence) { - if (header == synclog_header_cp_async_fence) { - synclog_print_prefix("cp_async_fence", at); - at += synclog_length_cp_async_fence; - printf("\n"); - continue; - } - } - if constexpr (synclog_enable_cp_async_nan) { - if (header == synclog_header_cp_async_nan) { - synclog_print_prefix("cp_async_nan", at); - at += synclog_length_cp_async_nan; - uint64_t gmem_addr = synclog_buf[at-3]; - gmem_addr += (uint64_t)synclog_buf[at-2] << 32; - printf("smem_addr=%u gmem_addr=%llu pred=%u\n", synclog_buf[at-4], gmem_addr, synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cp_async_zfill) { - if (header == synclog_header_cp_async_zfill) { - synclog_print_prefix("cp_async_zfill", at); - at += synclog_length_cp_async_zfill; - uint64_t gmem_addr = synclog_buf[at-4]; - gmem_addr += (uint64_t)synclog_buf[at-3] << 32; - printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_cp_async) { - if (header == synclog_header_cp_async) { - synclog_print_prefix("cp_async", at); - at += synclog_length_cp_async; - uint64_t gmem_addr = synclog_buf[at-4]; - gmem_addr += (uint64_t)synclog_buf[at-3] << 32; - printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_tma_load) { - if (header == synclog_header_tma_load) { - synclog_print_prefix("tma_load", at); - at += synclog_length_tma_load; - uint64_t gmem_int_desc = synclog_buf[at-4]; - gmem_int_desc += (uint64_t)synclog_buf[at-3] << 32; - printf("gmem_int_desc=%llu smem_int_mbar=%u smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-2], synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_tma_store) { - if (header == synclog_header_tma_store) { - synclog_print_prefix("tma_store", at); - at += synclog_length_tma_store; - uint64_t gmem_int_desc = synclog_buf[at-3]; - gmem_int_desc += (uint64_t)synclog_buf[at-2] << 32; - printf("gmem_int_desc=%llu smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_tma_store_arrive) { - if (header == synclog_header_tma_store_arrive) { - synclog_print_prefix("tma_store_arrive", at); - at += synclog_length_tma_store_arrive; - printf("\n"); - continue; - } - } - if constexpr (synclog_enable_tma_store_wait) { - if (header == synclog_header_tma_store_wait) { - synclog_print_prefix("tma_store_wait", at); - at += synclog_length_tma_store_wait; - printf("count=%u\n", synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_warpgroup_arrive) { - if (header == synclog_header_warpgroup_arrive) { - synclog_print_prefix("warpgroup_arrive", at); - at += synclog_length_warpgroup_arrive; - printf("\n"); - continue; - } - } - if constexpr (synclog_enable_warpgroup_wait) { - if (header == synclog_header_warpgroup_wait) { - synclog_print_prefix("warpgroup_wait", at); - at += synclog_length_warpgroup_wait; - printf("n=%u\n", synclog_buf[at-1]); - continue; - } - } - if constexpr (synclog_enable_warpgroup_commit_batch) { - if (header == synclog_header_warpgroup_commit_batch) { - synclog_print_prefix("warpgroup_commit_batch", at); - at += synclog_length_warpgroup_commit_batch; - printf("\n"); - continue; - } - } - if constexpr (synclog_enable_wgmma_reg_smem) { - if (header == synclog_header_wgmma_reg_smem) { - synclog_print_prefix("wgmma_reg_smem", at); - at += synclog_length_wgmma_reg_smem; - synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], ""); - printf("\n"); - continue; - } - } - if constexpr (synclog_enable_wgmma_smem_smem) { - if (header == synclog_header_wgmma_smem_smem) { - synclog_print_prefix("wgmma_smem_smem", at); - at += synclog_length_wgmma_smem_smem; - synclog_print_wgmma_desc("desc_a", synclog_buf[at-4], synclog_buf[at-3], " "); - synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], ""); - printf("\n"); - continue; - } - } - if constexpr (synclog_enable_cpasync_barrier_arrive) { - if (header == synclog_header_cpasync_barrier_arrive) { - synclog_print_prefix("cpasync_barrier_arrive", at); - at += synclog_length_cpasync_barrier_arrive; - printf("smem_addr=%u\n", synclog_buf[at-1]); - continue; - } - } - asm volatile ("brkpt;\n" ::); - } - if (synclog_buf[0] >= synclog_cap) { - printf( - "synclog was truncated (exceeded capacity of %lu bytes)\n", - (synclog_cap - 1) * sizeof(uint32_t) - ); - } - printf("synclog end\n"); - #endif - #endif // defined(CUTLASS_ENABLE_SYNCLOG) -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -#if defined(CUTLASS_ENABLE_SYNCLOG) -#undef __syncthreads -#define __syncthreads() do {\ - cutlass::arch::synclog_emit_syncthreads(__LINE__);\ - __syncthreads();\ -} while (0) -#endif // defined(CUTLASS_ENABLE_SYNCLOG) - -#if defined(CUTLASS_ENABLE_SYNCLOG) -#undef __syncwarp -#define __syncwarp(...) do {\ - cutlass::arch::synclog_emit_syncwarp(__LINE__);\ - __syncwarp(__VA_ARGS__);\ -} while (0) -#endif // defined(CUTLASS_ENABLE_SYNCLOG) - - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma.h deleted file mode 100644 index 2d4861ab682aca73d40ad5d0f298f9a265f7b9f2..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma.h +++ /dev/null @@ -1,218 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing architecture support for warp matrix multiply-add (WMMA) operations -*/ - -#pragma once - -#if (__CUDACC_VER_MAJOR__ >= 9) -#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700)) -#define CUTLASS_ARCH_WMMA_ENABLED -#define CUTLASS_ARCH_WMMA_SM70_ENABLED -#endif -#endif - -#if (__CUDACC_VER_MAJOR__ >= 10) -#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 720)) -#define CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED -#define CUTLASS_ARCH_WMMA_SM72_ENABLED -#endif -#endif - -#if (__CUDACC_VER_MAJOR__ >= 10) -#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750)) -#define CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -#define CUTLASS_ARCH_WMMA_SM75_ENABLED -#endif -#endif - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) - -#include -#include "cutlass/arch/mma.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/gemm/gemm.h" - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -//////////////////////////////////////////////////////////////////////////////////////////////// -/// Statically maps cutlass data types => nvcuda::wmma data types -///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct CutlassToWmmaDataType{ - using Type = Type_; -}; - -/// Statically maps cutlass::half_t => __half -template<> -struct CutlassToWmmaDataType { - using Type = __half; -}; - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) -template<> -struct CutlassToWmmaDataType { - using Type = __nv_bfloat16; -}; -#endif - -/// Statically maps int8_t => char -template<> -struct CutlassToWmmaDataType { - using Type = signed char; -}; - -/// Statically maps uint8_t => char -template<> -struct CutlassToWmmaDataType { - using Type = unsigned char; -}; - -/// Statically maps int32_t => int -template<> -struct CutlassToWmmaDataType { - using Type = int; -}; - -#if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED) -/// Statically maps cutlass::int4b_t => experimental::precision::s4 -template<> -struct CutlassToWmmaDataType { - using Type = nvcuda::wmma::experimental::precision::s4; -}; - -/// Statically maps cutlass::uint4b_t => experimental::precision::s4 -template<> -struct CutlassToWmmaDataType { - using Type = nvcuda::wmma::experimental::precision::u4; -}; - -/// Statically maps cutlass::uint1b_t => experimental::precision::b1 -template<> -struct CutlassToWmmaDataType { - using Type = nvcuda::wmma::experimental::precision::b1; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////// -/// Statically maps cutlass::layout => nvcuda::wmma layout tags -//////////////////////////////////////////////////////////////////////////////////////////////// -template -struct CutlassToWmmaLayout { -}; - -/// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags -template <> -struct CutlassToWmmaLayout { - using Layout = nvcuda::wmma::row_major; - static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_row_major; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////// -/// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags -//////////////////////////////////////////////////////////////////////////////////////////////// -template <> -struct CutlassToWmmaLayout { - using Layout = nvcuda::wmma::col_major; - static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_col_major; -}; -//////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////////////////////// -/// Statically maps nvcuda::wmma data types => cutlass data types -///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct WmmaToCutlassDataType{ - using Type = Type_; -}; - -/// Statically maps __half => cutlass::half_t -template<> -struct WmmaToCutlassDataType<__half> { - using Type = cutlass::half_t; -}; - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) -template<> -struct WmmaToCutlassDataType<__nv_bfloat16> { - using Type = cutlass::bfloat16_t; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// WMMA template structure defines nvcuda::wmma::fragments and static assertion chaeks -// for a specific template parameterized data type (Element[A|B|C]), layout (Layout[A|B|C]), -// and native wmma size (Shape) -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Shape_, ///< Size of the matrix product (concept: GemmShape) - typename ElementA_, ///< Data type of A elements - typename LayoutA_, ///< Layout of A matrix (concept: MatrixLayout) - typename ElementB_, ///< Data type of B elements - typename LayoutB_, ///< Layout of B matrix (concept: MatrixLayout) - typename ElementC_, ///< Element type of C matrix - typename LayoutC_, /// Layout of C matrix (concept: MatrixLayout) - typename Operator_ = cutlass::arch::OpMultiplyAdd ///< Inner product operator (multiply-add, xor.popc) -> -struct Wmma; -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Specializations for each compute capability -// -#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -#include "cutlass/arch/wmma_sm70.h" -#endif - -#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -#include "cutlass/arch/wmma_sm72.h" -#endif - -#ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED -#include "cutlass/arch/wmma_sm75.h" -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#endif //CUTLASS_ARCH_WMMA_ENABLED diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm70.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm70.h deleted file mode 100644 index 2c540be88577b448a2abc75cf6478736a41eb716..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm70.h +++ /dev/null @@ -1,132 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Matrix multiply -*/ - -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) -#include "cutlass/layout/matrix.h" - -//////////////////////////////////////////////////////////////////////////////// -namespace cutlass { -namespace arch { - - -//////////////////////////////////////////////////////////////////////////////// -// -// WMMA template structure defines nvcuda::wmma::fragments and static assert for -// wmma native instruction sizes supported for half -// -//////////////////////////////////////////////////////////////////////////////// -template < -typename Shape_, -typename LayoutA_, -typename LayoutB_, -typename ElementC_, -typename LayoutC_> -struct Wmma< - Shape_, ///< Size of the matrix product (concept: GemmShape) - cutlass::half_t, ///< ElementA - LayoutA_, ///< LayoutA - cutlass::half_t, ///< ElementB - LayoutB_, ///< LayoutB - ElementC_, ///< ElementC - LayoutC_, ///< LayoutC - cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -> { - -#if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED) - using Shape = Shape_; - using ElementA = cutlass::half_t; - using LayoutA = LayoutA_; - using ElementB = cutlass::half_t; - using LayoutB = LayoutB_; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using Operator = cutlass::arch::OpMultiplyAdd; - using ArchTag = arch::Sm70; - - // check supported wmma shape for the given multiplicand data types - static_assert( - platform::is_same, Shape>::value || - platform::is_same, Shape>::value || - platform::is_same, Shape>::value, - "Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); - - // check supported wmma output data type for the given multiplicand data types - static_assert( - platform::is_same::value || platform::is_same::value, - "Supported of wmma output data type for f16 multiplicands are: f16 and f32"); - - // Wmma Fragment - using FragmentA = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_a, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type, - typename CutlassToWmmaLayout::Layout>; - - using FragmentB = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_b, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type, - typename CutlassToWmmaLayout::Layout>; - - using FragmentC = nvcuda::wmma::fragment< - nvcuda::wmma::accumulator, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type>; - - /// Performs a nvcuda::wmma matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C) const { - - nvcuda::wmma::mma_sync(D, A, B, C); - } -#else - static_assert(false, "wmma.mma.sync for floating point multiplicands is available only for SM70 and beyond"); -#endif - -}; - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm72.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm72.h deleted file mode 100644 index 1eb553e8f311e66e08e47dab15c6b08c29dec81c..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm72.h +++ /dev/null @@ -1,206 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Matrix multiply -*/ - -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) -#include "cutlass/layout/matrix.h" - -//////////////////////////////////////////////////////////////////////////////// -namespace cutlass { -namespace arch { - -//////////////////////////////////////////////////////////////////////////////// -// -// WMMA template structure defines nvcuda::wmma::fragments and static assert for -// wmma native instruction sizes supported for int8_t -// -//////////////////////////////////////////////////////////////////////////////// -template < -typename Shape_, -typename LayoutA_, -typename LayoutB_, -typename LayoutC_> -struct Wmma< - Shape_, ///< Size of the matrix product (concept: GemmShape) - int8_t, ///< ElementA - LayoutA_, ///< LayoutA - int8_t, ///< ElementB - LayoutB_, ///< LayoutB - int32_t, ///< ElementC - LayoutC_, ///< LayoutC - cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -> { -#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) - using Shape = Shape_; - using ElementA = int8_t; - using LayoutA = LayoutA_; - using ElementB = int8_t; - using LayoutB = LayoutB_; - using ElementC = int32_t; - using LayoutC = LayoutC_; - using Operator = cutlass::arch::OpMultiplyAdd; - using ArchTag = arch::Sm72; - - // check supported wmma shape for the given multiplicand data types - static_assert( - platform::is_same, Shape>::value || - platform::is_same, Shape>::value || - platform::is_same, Shape>::value, - "Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); - - - // Wmma Fragment - using FragmentA = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_a, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type, - typename CutlassToWmmaLayout::Layout>; - - using FragmentB = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_b, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type, - typename CutlassToWmmaLayout::Layout>; - - using FragmentC = nvcuda::wmma::fragment< - nvcuda::wmma::accumulator, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type>; - - /// Performs a nvcuda::wmma matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C) const { - - nvcuda::wmma::mma_sync(D, A, B, C); - } - -#else - static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond"); -#endif - -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// WMMA template structure defines nvcuda::wmma::fragments and static assert for -// wmma native instruction sizes supported for uint8_t -// -//////////////////////////////////////////////////////////////////////////////// -template < -typename Shape_, -typename LayoutA_, -typename LayoutB_, -typename LayoutC_> -struct Wmma< - Shape_, ///< Size of the matrix product (concept: GemmShape) - uint8_t, ///< ElementA - LayoutA_, ///< LayoutA - uint8_t, ///< ElementB - LayoutB_, ///< LayoutB - int32_t, ///< ElementC - LayoutC_, ///< LayoutC - cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -> { -#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) - using Shape = Shape_; - using ElementA = uint8_t; - using LayoutA = LayoutA_; - using ElementB = uint8_t; - using LayoutB = LayoutB_; - using ElementC = int32_t; - using LayoutC = LayoutC_; - using Operator = cutlass::arch::OpMultiplyAdd; - using ArchTag = arch::Sm72; - - // check supported wmma shape for the given multiplicand data types - static_assert( - platform::is_same, Shape>::value || - platform::is_same, Shape>::value || - platform::is_same, Shape>::value, - "Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); - - // Wmma Fragment - using FragmentA = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_a, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type, - typename CutlassToWmmaLayout::Layout>; - - using FragmentB = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_b, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type, - typename CutlassToWmmaLayout::Layout>; - - using FragmentC = nvcuda::wmma::fragment< - nvcuda::wmma::accumulator, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type>; - - /// Performs a nvcuda::wmma matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C) const { - - nvcuda::wmma::mma_sync(D, A, B, C); - } - -#else - static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond"); -#endif - -}; - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm75.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm75.h deleted file mode 100644 index c3535ef0748e53b204b7d20cdd4aa82edc8c72a8..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm75.h +++ /dev/null @@ -1,203 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Matrix multiply -*/ - -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) -#include "cutlass/layout/matrix.h" - -//////////////////////////////////////////////////////////////////////////////// -namespace cutlass { -namespace arch { - -//////////////////////////////////////////////////////////////////////////////// -// -// WMMA template structure defines nvcuda::wmma::fragments and static assert for -// wmma native instruction sizes supported for cutlass::int4b_t (experimental::s4). -// -//////////////////////////////////////////////////////////////////////////////// -template < -typename Shape_, -typename LayoutA_, -typename LayoutB_, -typename LayoutC_> -struct Wmma< - Shape_, ///< Size of the matrix product (concept: GemmShape) - cutlass::int4b_t, ///< ElementA - LayoutA_, ///< LayoutA - cutlass::int4b_t, ///< ElementB - LayoutB_, ///< LayoutB - int32_t, ///< ElementC - LayoutC_, ///< LayoutC - cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -> { -#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) - using Shape = Shape_; - using ElementA = cutlass::int4b_t; - using LayoutA = LayoutA_; - using ElementB = cutlass::int4b_t; - using LayoutB = LayoutB_; - using ElementC = int32_t; - using LayoutC = LayoutC_; - using Operator = cutlass::arch::OpMultiplyAdd; - using ArchTag = arch::Sm75; - - // check supported wmma shape for the given multiplicand data types - static_assert( - platform::is_same, Shape>::value, - "Supported list of wmma operator shape for s8 multiplicands is: 8x8x32"); - - - // Wmma Fragment - using FragmentA = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_a, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type, - typename CutlassToWmmaLayout::Layout>; - - using FragmentB = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_b, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type, - typename CutlassToWmmaLayout::Layout>; - - using FragmentC = nvcuda::wmma::fragment< - nvcuda::wmma::accumulator, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type>; - - /// Performs a nvcuda::wmma matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C) const { - nvcuda::wmma::mma_sync(D, A, B, C); - - } - -#else - static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond"); -#endif - -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// WMMA template structure defines nvcuda::wmma::fragments and static assert for -// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1). -// -//////////////////////////////////////////////////////////////////////////////// -template < -typename Shape_, -typename LayoutA_, -typename LayoutB_, -typename LayoutC_> -struct Wmma< - Shape_, ///< Size of the matrix product (concept: GemmShape) - cutlass::uint1b_t, ///< ElementA - LayoutA_, ///< LayoutA - cutlass::uint1b_t, ///< ElementB - LayoutB_, ///< LayoutB - int32_t, ///< ElementC - LayoutC_, ///< LayoutC - cutlass::arch::OpXorPopc ///< Operator (multiply-add, xor.popc) -> { -#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) - using Shape = Shape_; - using ElementA = cutlass::uint1b_t; - using LayoutA = LayoutA_; - using ElementB = cutlass::uint1b_t; - using LayoutB = LayoutB_; - using ElementC = int32_t; - using LayoutC = LayoutC_; - using Operator = cutlass::arch::OpXorPopc; - using ArchTag = arch::Sm75; - - // check supported wmma shape for the given multiplicand data types - static_assert( - platform::is_same, Shape>::value, - "Supported list of wmma operator shape for b1 multiplicands is: 8x8x128"); - - - // Wmma Fragment - using FragmentA = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_a, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type, - typename CutlassToWmmaLayout::Layout>; - - using FragmentB = nvcuda::wmma::fragment< - nvcuda::wmma::matrix_b, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type, - typename CutlassToWmmaLayout::Layout>; - - using FragmentC = nvcuda::wmma::fragment< - nvcuda::wmma::accumulator, - Shape::kM, - Shape::kN, - Shape::kK, - typename CutlassToWmmaDataType::Type>; - - /// Performs a nvcuda::wmma matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C) const { - nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, - nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); - } - -#else - static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond"); -#endif - -}; - -} // namespace arch -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/array.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/array.h deleted file mode 100644 index ce33110aa4f44e7deba56a5f9fe4db206a6889ce..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/array.h +++ /dev/null @@ -1,2860 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types - and is safe to use in a union. -*/ - -#pragma once -#include "cutlass/cutlass.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_types.h" -#include "cutlass/platform/platform.h" -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Statically sized array for any data type -template < - typename T, - int N, - bool RegisterSized = sizeof_bits::value >= 32 -> -struct Array; - -namespace detail { - -template -struct is_Array : platform::false_type {}; - -template < - typename T, - int N, - bool RegisterSized -> -struct is_Array > : platform::true_type {}; - -template -constexpr bool is_Array_v = is_Array::value; - -} // namespace detail - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines the size of an Array<> in bits -template -struct sizeof_bits > { - static constexpr int value = sizeof(Array) * 8; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns true if the argument is a power of 2 -CUTLASS_HOST_DEVICE -constexpr bool ispow2(unsigned x) { - return x && (!(x & (x - 1))); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns the largest power of two not greater than the argument. -CUTLASS_HOST_DEVICE -constexpr unsigned floor_pow_2(unsigned x) { - return (x == 0 || ispow2(x)) ? x : ((floor_pow_2(x >> 1)) << 1); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Statically sized array for any data type -template < - typename T, - int N -> -struct Array { - - /// Storage type - using Storage = T; - - /// Element type - using Element = T; - - /// Number of storage elements - //static std::size_t const kStorageElements = N; - static constexpr size_t kStorageElements = N; - - /// Number of logical elements - static constexpr size_t kElements = N; - - // - // C++ standard members - // - - typedef T value_type; - typedef size_t size_type; - typedef ptrdiff_t difference_type; - typedef value_type &reference; - typedef value_type const & const_reference; - typedef value_type *pointer; - typedef value_type const * const_pointer; - - // - // Iterators - // - - /// Bidirectional iterator over elements - class iterator { - - /// Pointer to object - T *ptr_; - - public: - - CUTLASS_HOST_DEVICE - iterator(): ptr_(nullptr) { } - - CUTLASS_HOST_DEVICE - iterator(T *_ptr): ptr_(_ptr) { } - - CUTLASS_HOST_DEVICE - iterator &operator++() { - ++ptr_; - return *this; - } - - CUTLASS_HOST_DEVICE - iterator &operator--() { - --ptr_; - return *this; - } - - CUTLASS_HOST_DEVICE - iterator operator++(int) { - iterator ret(*this); - ++ptr_; - return ret; - } - - CUTLASS_HOST_DEVICE - iterator operator--(int) { - iterator ret(*this); - --ptr_; - return ret; - } - - CUTLASS_HOST_DEVICE - T &operator*() const { - return *ptr_; - } - - CUTLASS_HOST_DEVICE - bool operator==(iterator const &other) const { - return ptr_ == other.ptr_; - } - - CUTLASS_HOST_DEVICE - bool operator!=(iterator const &other) const { - return ptr_ != other.ptr_; - } - }; - - /// Bidirectional constant iterator over elements - class const_iterator { - - /// Pointer to object - const T *ptr_; - - public: - - CUTLASS_HOST_DEVICE - const_iterator(): ptr_(nullptr) { } - - CUTLASS_HOST_DEVICE - const_iterator(T const *_ptr): ptr_(_ptr) { } - - CUTLASS_HOST_DEVICE - const_iterator &operator++() { - ++ptr_; - return *this; - } - - CUTLASS_HOST_DEVICE - const_iterator &operator--() { - --ptr_; - return *this; - } - - CUTLASS_HOST_DEVICE - const_iterator operator++(int) { - const_iterator ret(*this); - ++ptr_; - return ret; - } - - CUTLASS_HOST_DEVICE - const_iterator operator--(int) { - const_iterator ret(*this); - --ptr_; - return ret; - } - - CUTLASS_HOST_DEVICE - T const &operator*() const { - return *ptr_; - } - - CUTLASS_HOST_DEVICE - bool operator==(const_iterator const &other) const { - return ptr_ == other.ptr_; - } - - CUTLASS_HOST_DEVICE - bool operator!=(const_iterator const &other) const { - return ptr_ != other.ptr_; - } - }; - - /// Bidirectional iterator over elements - class reverse_iterator { - - /// Pointer to object - T *ptr_; - - public: - - CUTLASS_HOST_DEVICE - reverse_iterator(): ptr_(nullptr) { } - - CUTLASS_HOST_DEVICE - reverse_iterator(T *_ptr): ptr_(_ptr) { } - - CUTLASS_HOST_DEVICE - reverse_iterator &operator++() { - --ptr_; - return *this; - } - - CUTLASS_HOST_DEVICE - reverse_iterator &operator--() { - ++ptr_; - return *this; - } - - CUTLASS_HOST_DEVICE - reverse_iterator operator++(int) { - iterator ret(*this); - --ptr_; - return ret; - } - - CUTLASS_HOST_DEVICE - reverse_iterator operator--(int) { - iterator ret(*this); - ++ptr_; - return ret; - } - - CUTLASS_HOST_DEVICE - T &operator*() const { - return *(ptr_ - 1); - } - - CUTLASS_HOST_DEVICE - bool operator==(reverse_iterator const &other) const { - return ptr_ == other.ptr_; - } - - CUTLASS_HOST_DEVICE - bool operator!=(reverse_iterator const &other) const { - return ptr_ != other.ptr_; - } - }; - - /// Bidirectional constant iterator over elements - class const_reverse_iterator { - - /// Pointer to object - T const *ptr_; - - public: - - CUTLASS_HOST_DEVICE - const_reverse_iterator(): ptr_(nullptr) { } - - CUTLASS_HOST_DEVICE - const_reverse_iterator(T const *_ptr): ptr_(_ptr) { } - - CUTLASS_HOST_DEVICE - const_reverse_iterator &operator++() { - --ptr_; - return *this; - } - - CUTLASS_HOST_DEVICE - const_reverse_iterator &operator--() { - ++ptr_; - return *this; - } - - CUTLASS_HOST_DEVICE - const_reverse_iterator operator++(int) { - const_reverse_iterator ret(*this); - --ptr_; - return ret; - } - - CUTLASS_HOST_DEVICE - const_reverse_iterator operator--(int) { - const_reverse_iterator ret(*this); - ++ptr_; - return ret; - } - - CUTLASS_HOST_DEVICE - T const &operator*() const { - return *(ptr_ - 1); - } - - CUTLASS_HOST_DEVICE - bool operator==(const_iterator const &other) const { - return ptr_ == other.ptr_; - } - - CUTLASS_HOST_DEVICE - bool operator!=(const_iterator const &other) const { - return ptr_ != other.ptr_; - } - }; - - /// Internal storage - Storage storage[kElements]; - - /// Efficient clear method - CUTLASS_HOST_DEVICE - void clear() { - fill(T(0)); - } - - CUTLASS_HOST_DEVICE - reference at(size_type pos) { - return reinterpret_cast(storage[pos]); - } - - CUTLASS_HOST_DEVICE - const_reference at(size_type pos) const { - return reinterpret_cast(storage[pos]); - } - - CUTLASS_HOST_DEVICE - reference operator[](size_type pos) { - return reinterpret_cast(storage[pos]); - } - - CUTLASS_HOST_DEVICE - const_reference operator[](size_type pos) const { - return reinterpret_cast(storage[pos]); - } - - CUTLASS_HOST_DEVICE - reference front() { - return reinterpret_cast(storage[0]); - } - - CUTLASS_HOST_DEVICE - const_reference front() const { - return reinterpret_cast(storage[0]); - } - - CUTLASS_HOST_DEVICE - reference back() { - return reinterpret_cast(storage[kStorageElements - 1]); - } - - CUTLASS_HOST_DEVICE - const_reference back() const { - return reinterpret_cast(storage[kStorageElements - 1]); - } - - CUTLASS_HOST_DEVICE - pointer data() { - return reinterpret_cast(storage); - } - - CUTLASS_HOST_DEVICE - const_pointer data() const { - return reinterpret_cast(storage); - } - - CUTLASS_HOST_DEVICE - pointer raw_data() { - return reinterpret_cast(storage); - } - - CUTLASS_HOST_DEVICE - const_pointer raw_data() const { - return reinterpret_cast(storage); - } - - - CUTLASS_HOST_DEVICE - constexpr bool empty() const { - return !kElements; - } - - CUTLASS_HOST_DEVICE - constexpr size_type size() const { - return kElements; - } - - CUTLASS_HOST_DEVICE - constexpr size_type max_size() const { - return kElements; - } - - CUTLASS_HOST_DEVICE - void fill(T const &value) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < int(kElements); ++i) { - storage[i] = static_cast(value); - } - } - - CUTLASS_HOST_DEVICE - iterator begin() { - return iterator(storage); - } - - CUTLASS_HOST_DEVICE - const_iterator begin() const { - return cbegin(); - } - - CUTLASS_HOST_DEVICE - const_iterator cbegin() const { - return const_iterator(storage); - } - - CUTLASS_HOST_DEVICE - iterator end() { - return iterator(reinterpret_cast(storage + kStorageElements)); - } - - CUTLASS_HOST_DEVICE - const_iterator end() const { - return cend(); - } - - CUTLASS_HOST_DEVICE - const_iterator cend() const { - return const_iterator(reinterpret_cast(storage + kStorageElements)); - } - - CUTLASS_HOST_DEVICE - reverse_iterator rbegin() { - return reverse_iterator(reinterpret_cast(storage + kStorageElements)); - } - - CUTLASS_HOST_DEVICE - const_reverse_iterator rbegin() const { - return crbegin(); - } - - CUTLASS_HOST_DEVICE - const_reverse_iterator crbegin() const { - return const_reverse_iterator(reinterpret_cast(storage + kStorageElements)); - } - - CUTLASS_HOST_DEVICE - reverse_iterator rend() { - return reverse_iterator(reinterpret_cast(storage)); - } - - CUTLASS_HOST_DEVICE - const_reverse_iterator rend() const { - return crend(); - } - - CUTLASS_HOST_DEVICE - const_reverse_iterator crend() const { - return const_reverse_iterator(reinterpret_cast(storage)); - } - - // - // Comparison operators - // - -}; - - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Factories -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_HOST_DEVICE -Array make_Array(Element x) { - return {x}; -} - -template -CUTLASS_HOST_DEVICE -Array make_Array(Element x, Element y) { - return {x,y}; -} - -template -CUTLASS_HOST_DEVICE -Array make_Array(Element x, Element y, Element z) { - return {x,y,z}; -} - -template -CUTLASS_HOST_DEVICE -Array make_Array(Element x, Element y, Element z, Element w) { - return {x,y,z,w}; -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// functional.h numeric specializations -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct absolute_value_op< Array > { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs) const { - - Array result; - absolute_value_op scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i]); - } - - return result; - } -}; - -template -struct plus> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - plus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - plus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - plus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; -template -struct minus> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - minus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - minus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - minus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct multiplies> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - multiplies scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - multiplies scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - multiplies scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct maximum_absolute_value_reduction, PropogateNaN> { - - CUTLASS_HOST_DEVICE - T operator() (T const& scalar, Array const& rhs) const { - - T result = scalar; - maximum_absolute_value_reduction scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result = scalar_op(result, rhs[i]); - } - - return result; - } -}; - -template -struct scale> { - T const scaling_factor_; - - CUTLASS_HOST_DEVICE - scale(T scaling_factor) : scaling_factor_(scaling_factor) { - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & rhs) const { - Array result; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = rhs[i] * scaling_factor_; - } - - return result; - } -}; - -template -struct divides> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - divides scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - divides scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - divides scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct reciprocal_approximate> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs) const { - - Array result; - reciprocal_approximate scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i]); - } - - return result; - } -}; - -template -struct reciprocal_approximate_ftz> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs) const { - - Array result; - reciprocal_approximate_ftz scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i]); - } - - return result; - } -}; - -template -struct maximum, PropagateNaN> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct minimum, PropagateNaN> { - - CUTLASS_HOST_DEVICE - static T scalar_op(T const &lhs, T const &rhs) { - return (rhs < lhs ? rhs : lhs); - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - minimum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - minimum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &rhs) const { - - Array result; - minimum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct minimum_with_nan_propagation> : minimum, true> -{}; - -template -struct negate> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs) const { - - Array result; - negate scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i]); - } - - return result; - } -}; - -/// Fused multiply-add -template -struct multiply_add, Array, Array> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b, Array const &c) const { - - Array result; - multiply_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], b[i], c[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, T const &scalar, Array const &c) const { - - Array result; - multiply_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], scalar, c[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &b, Array const &c) const { - - Array result; - multiply_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, b[i], c[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b, T const &scalar) const { - - Array result; - multiply_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], b[i], scalar); - } - - return result; - } - - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, T const &scalar_b, T const &scalar_c) const { - - Array result; - multiply_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], scalar_b, scalar_c); - } - - return result; - } -}; - -/// Fused square-and-plus -template -struct square_and_plus> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - multiply_add, Array, Array> ma_op; - return ma_op(rhs, rhs, lhs); - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &rhs) const { - plus> plus_op; - multiplies multiplies_op; - return plus_op(multiplies_op(rhs, rhs), lhs); - } -}; - -/// Inverse-square-root -template -struct inverse_square_root> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a) const { - Array result; - inverse_square_root scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i]); - } - return result; - } -}; - -template -struct inverse_square_root> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & a) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = h2rsqrt(a_ptr[i]); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half d_residual = hrsqrt(a_residual_ptr[N - 1]); - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - inverse_square_root scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i]); - } - - #endif - - return result; - } -}; - -/// Fused multiply-add-relu0 -template -struct multiply_add_relu0, Array, Array> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b, Array const &c) const { - - Array result; - multiply_add scalar_op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0)); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, T const &scalar, Array const &c) const { - - Array result; - multiply_add scalar_op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0)); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &b, Array const &c) const { - - Array result; - multiply_add scalar_op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0)); - } - - return result; - } -}; - - -template -struct conjugate > { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a) const { - - conjugate conj_op; - - Array ca; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - ca[i] = conj_op(a[i]); - } - return ca; - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// functional.h numeric specializations targeting SIMD instructions in device code. -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct plus> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] + rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs + rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] + rhs; - } - #endif - - return result; - } -}; - -template -struct minus> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] - rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs - rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] - rhs; - } - #endif - - return result; - } -}; - -template -struct multiplies> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] * rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = __hmul( - reinterpret_cast<__half const &>(lhs), - b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs * rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - - __half d_residual = __hmul( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] * rhs; - } - #endif - - return result; - } -}; - -template -struct divides> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = __hdiv( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] / rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = __hdiv( - reinterpret_cast<__half const &>(lhs), - b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs / rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - - __half d_residual = __hdiv( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] / rhs; - } - #endif - - return result; - } -}; - -template -struct negate> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hneg2(source_ptr[i]); - } - - if constexpr (N % 2) { - half_t x = -lhs[N - 1]; - __half lhs_val = reinterpret_cast<__half const &>(x); - result[N - 1] = reinterpret_cast(lhs_val); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = -lhs[i]; - } - #endif - - return result; - } -}; - -/// Fused multiply-add -template -struct multiply_add, Array, Array> { - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]); - } - - if constexpr (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - - __half d_residual = __hfma( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1], - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b[i], c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - half_t const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]); - } - - if constexpr (N % 2) { - - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - __half d_residual = __hfma( - reinterpret_cast<__half const &>(a), - b_residual_ptr[N - 1], - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a, b[i], c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - half_t const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]); - } - - if constexpr (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - - __half d_residual = __hfma( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(b), - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b, c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - half_t const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair); - } - - if constexpr (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - - __half d_residual = __hfma( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1], - reinterpret_cast<__half const &>(c)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b[i], c); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - half_t const &b, - half_t const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); - __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_pair); - } - - if constexpr (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - - __half d_residual = __hfma( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(b), - reinterpret_cast<__half const &>(c)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b, c); - } - #endif - - return result; - } -}; - -/// Fused multiply-add-relu0 -template -struct multiply_add_relu0, Array, Array> { - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]); - } - - if constexpr (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - - __half d_residual = __hfma_relu( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1], - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(op(a[i], b[i], c[i]), (half_t)0); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - half_t const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]); - } - - if constexpr (N % 2) { - - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - __half d_residual = __hfma_relu( - reinterpret_cast<__half const &>(a), - b_residual_ptr[N - 1], - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(op(a, b[i], c[i]), half_t(0)); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - half_t const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]); - } - - if constexpr (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - - __half d_residual = __hfma_relu( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(b), - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(op(a[i], b, c[i]), half_t(0)); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - half_t const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair); - } - - if constexpr (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - - __half d_residual = __hfma_relu( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1], - reinterpret_cast<__half const &>(c)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(op(a[i], b[i], c), half_t(0)); - } - #endif - - return result; - } -}; - -template -struct minimum, PropagateNaN> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_ptr[i]) - : __hmin2(lhs_ptr[i], rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1]) - : __hmin(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - minimum mn; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mn(lhs[i],rhs[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_pair, rhs_ptr[i]) - : __hmin2(lhs_pair, rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = PropagateNaN ? __hmin_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]) - : __hmin(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - minimum mn; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mn(lhs, rhs[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_pair) - : __hmin2(lhs_ptr[i], rhs_pair); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - - __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)) - : __hmin(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - minimum mn; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mn(lhs[i], rhs); - } - #endif - - return result; - } -}; - -template -struct maximum, PropagateNaN> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_ptr[i]) - : __hmax2(lhs_ptr[i], rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = PropagateNaN ? __hmax(a_residual_ptr[N - 1], b_residual_ptr[N - 1]) - : __hmax_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(lhs[i], rhs[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_pair, rhs_ptr[i]) - : __hmax2(lhs_pair, rhs_ptr[i]); - } - - if constexpr (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = PropagateNaN ? __hmax_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]) - : __hmax(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(lhs, rhs[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_pair) - : __hmax2(lhs_ptr[i], rhs_pair); - } - - if constexpr (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - - __half d_residual = PropagateNaN ? __hmax_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)) - : __hmax(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(lhs[i], rhs); - } - #endif - - return result; - } -}; - -/// Fused multiply-add -template -struct multiply_add, Array, Array> { - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - unsigned *result_ptr = reinterpret_cast(&result); - unsigned const *a_ptr = reinterpret_cast(&a); - unsigned const *b_ptr = reinterpret_cast(&b); - unsigned const *c_ptr = reinterpret_cast(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" - : "=r"(result_ptr[i]) - : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i]) - ); - } - - if constexpr (N % 2) { - - uint16_t *result_ptr = reinterpret_cast(&result); - uint16_t const *a_residual_ptr = reinterpret_cast(&a); - uint16_t const *b_residual_ptr = reinterpret_cast(&b); - uint16_t const *c_residual_ptr = reinterpret_cast(&c); - - asm ("fma.rn.bf16 %0, %1, %2, %3;\n" - : "=h"(result_ptr[N - 1]) - : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) - ); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b[i], c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - bfloat16_t const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - unsigned *result_ptr = reinterpret_cast(&result); - - unsigned const *b_ptr = reinterpret_cast(&b); - unsigned const *c_ptr = reinterpret_cast(&c); - - unsigned a_packed = static_cast(a.raw()); - a_packed = (a_packed | (a_packed << 16)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" - : "=r"(result_ptr[i]) - : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i]) - ); - } - - if constexpr (N % 2) { - - uint16_t *result_ptr = reinterpret_cast(&result); - uint16_t const *a_residual_ptr = reinterpret_cast(&a); - uint16_t const *b_residual_ptr = reinterpret_cast(&b); - uint16_t const *c_residual_ptr = reinterpret_cast(&c); - - asm ("fma.rn.bf16 %0, %1, %2, %3;\n" - : "=h"(result_ptr[N - 1]) - : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) - ); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a, b[i], c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - bfloat16_t const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - unsigned *result_ptr = reinterpret_cast(&result); - - unsigned const *a_ptr = reinterpret_cast(&a); - unsigned const *c_ptr = reinterpret_cast(&c); - - unsigned b_packed = static_cast(b.raw()); - b_packed = (b_packed | (b_packed << 16)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" - : "=r"(result_ptr[i]) - : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i]) - ); - } - - if constexpr (N % 2) { - - uint16_t *result_ptr = reinterpret_cast(&result); - uint16_t const *a_residual_ptr = reinterpret_cast(&a); - uint16_t const *b_residual_ptr = reinterpret_cast(&b); - uint16_t const *c_residual_ptr = reinterpret_cast(&c); - - asm ("fma.rn.bf16 %0, %1, %2, %3;\n" - : "=h"(result_ptr[N - 1]) - : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[N - 1]) - ); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b, c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - bfloat16_t const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - unsigned *result_ptr = reinterpret_cast(&result); - - unsigned const *a_ptr = reinterpret_cast(&a); - unsigned const *b_ptr = reinterpret_cast(&b); - - unsigned c_packed = static_cast(c.raw()); - c_packed = (c_packed | (c_packed << 16)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" - : "=r"(result_ptr[i]) - : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed) - ); - } - - if constexpr (N % 2) { - - uint16_t *result_ptr = reinterpret_cast(&result); - uint16_t const *a_residual_ptr = reinterpret_cast(&a); - uint16_t const *b_residual_ptr = reinterpret_cast(&b); - uint16_t const *c_residual_ptr = reinterpret_cast(&c); - - asm ("fma.rn.bf16 %0, %1, %2, %3;\n" - : "=h"(result_ptr[N - 1]) - : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0]) - ); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b[i], c); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - bfloat16_t const &b, - bfloat16_t const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - unsigned *result_ptr = reinterpret_cast(&result); - - unsigned const *a_ptr = reinterpret_cast(&a); - - unsigned b_packed = static_cast(b.raw()); - b_packed = (b_packed | (b_packed << 16)); - - unsigned c_packed = static_cast(c.raw()); - c_packed = (c_packed | (c_packed << 16)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" - : "=r"(result_ptr[i]) - : "r"(a_ptr[i]), "r"(b_packed), "r"(c_packed) - ); - } - - if constexpr (N % 2) { - - uint16_t *result_ptr = reinterpret_cast(&result); - uint16_t const *a_residual_ptr = reinterpret_cast(&a); - uint16_t const *b_residual_ptr = reinterpret_cast(&b); - uint16_t const *c_residual_ptr = reinterpret_cast(&c); - - asm ("fma.rn.bf16 %0, %1, %2, %3;\n" - : "=h"(result_ptr[N - 1]) - : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[0]) - ); - } - - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b, c); - } - #endif - - return result; - } -}; - - -/// bit_and -template -struct bit_and> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b) const { - using ArrayType = Array; - using Storage = typename ArrayType::Storage; - ArrayType result; - - Storage *result_data = result.raw_data(); - Storage const *a_data = a.raw_data(); - Storage const *b_data = b.raw_data(); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ArrayType::kStorageElements; ++i) { - result_data[i] = (a_data[i] & b_data[i]); - } - - return result; - } -}; - - -/// bit_or -template -struct bit_or> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b) const { - using ArrayType = Array; - using Storage = typename ArrayType::Storage; - ArrayType result; - - Storage *result_data = result.raw_data(); - Storage const *a_data = a.raw_data(); - Storage const *b_data = b.raw_data(); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ArrayType::kStorageElements; ++i) { - result_data[i] = (a_data[i] | b_data[i]); - } - - return result; - } -}; - - -/// bit_not -template -struct bit_not> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a) const { - using ArrayType = Array; - using Storage = typename ArrayType::Storage; - ArrayType result; - - Storage *result_data = result.raw_data(); - Storage const *a_data = a.raw_data(); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ArrayType::kStorageElements; ++i) { - result_data[i] = (~a_data[i]); - } - - return result; - } -}; - -/// bit_xor -template -struct bit_xor> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b) const { - using ArrayType = Array; - using Storage = typename ArrayType::Storage; - ArrayType result; - - Storage *result_data = result.raw_data(); - Storage const *a_data = a.raw_data(); - Storage const *b_data = b.raw_data(); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ArrayType::kStorageElements; ++i) { - result_data[i] = (a_data[i] ^ b_data[i]); - } - - return result; - } -}; - -/// Fused and-popc-add -template -struct and_popc_add, Array, Array> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b, Array const &c) const { - Array result; - and_popc_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], b[i], c[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, T const &scalar, Array const &c) const { - Array result; - and_popc_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], scalar, c[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &b, Array const &c) const { - Array result; - and_popc_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, b[i], c[i]); - } - - return result; - } -}; - - -/// Fused or-popc-add -template -struct or_popc_add, Array, Array> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b, Array const &c) const { - Array result; - or_popc_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], b[i], c[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, T const &scalar, Array const &c) const { - Array result; - or_popc_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], scalar, c[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &b, Array const &c) const { - Array result; - or_popc_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, b[i], c[i]); - } - - return result; - } -}; - -/// Fused xor-popc-add -template -struct xor_popc_add, Array, Array> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b, Array const &c) const { - Array result; - xor_popc_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], b[i], c[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, T const &scalar, Array const &c) const { - Array result; - xor_popc_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], scalar, c[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &b, Array const &c) const { - Array result; - xor_popc_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, b[i], c[i]); - } - - return result; - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Operator overloads -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_HOST_DEVICE -Array operator+(Array const &lhs, Array const &rhs) { - plus> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator+(T const &lhs, Array const &rhs) { - plus> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator+(Array const &lhs, T const &rhs) { - plus> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator-(Array const &lhs, Array const &rhs) { - minus> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator-(Array const &lhs) { - negate> op; - return op(lhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator*(Array const &lhs, Array const &rhs) { - multiplies> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator*(T lhs, Array const &rhs) { - multiplies> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator*(Array const &lhs, T rhs) { - multiplies> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator/(Array const &lhs, Array const &rhs) { - divides> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array fma(Array const &a, Array const &b, Array const &c) { - multiply_add> op; - return op(a, b, c); -} - -template -CUTLASS_HOST_DEVICE -Array fma(T a, Array const &b, Array const &c) { - multiply_add> op; - return op(a, b, c); -} - -template -CUTLASS_HOST_DEVICE -Array fma(Array const &a, T b, Array const &c) { - multiply_add> op; - return op(a, b, c); -} - -template -CUTLASS_HOST_DEVICE -Array fma(Array const &a, Array const &b, T c) { - multiply_add> op; - return op(a, b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// AlignedArray -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Aligned array type -template < - /// Element type - typename T, - /// Number of elements in the array - int N, - /// Alignment requirement in bytes - int Alignment = ( sizeof_bits::value * N + 7 ) / 8 -> -class alignas(Alignment) AlignedArray: public Array { -public: - -}; - -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/array_subbyte.h" - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/array_planar_complex.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/array_planar_complex.h deleted file mode 100644 index 0bd9d0d7f7dc709b951c6979a3e26cf05ba9c79d..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/array_planar_complex.h +++ /dev/null @@ -1,89 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Array holding planar complex elements -template -struct ArrayPlanarComplex { - - /// Underlying real element - using Element = Element_; - - /// Number of logical elements - static constexpr size_t kElements = N; - - /// Underlying Fragment of real-valued elemenets - using ArrayReal = cutlass::Array; - -public: - /// Fragment of real-valued elements representing the real part - ArrayReal real; - - /// Fragment of real-valued elements representing the imaginary part - ArrayReal imag; - -public: - /// Sets the array to zero efficiently - CUTLASS_HOST_DEVICE - void clear() { - real.clear(); - imag.clear(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Helper to deduce template arguments -template -CUTLASS_HOST_DEVICE -ArrayPlanarComplex -make_ArrayPlanarComplex(Array const &real, Array const &imag) { - return ArrayPlanarComplex{real, imag}; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/array_subbyte.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/array_subbyte.h deleted file mode 100644 index 756890bb61f7ff5f2a9912b00b98a54deae6ee75..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/array_subbyte.h +++ /dev/null @@ -1,561 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types - and is safe to use in a union. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/platform/platform.h" - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Statically sized array for any data type -template < - typename T, - int N -> -struct Array { - static constexpr int kSizeBits = sizeof_bits::value * N; - - /// Storage type - using Storage = typename platform::conditional< - ((kSizeBits % 32) != 0), - typename platform::conditional< - ((kSizeBits % 16) != 0), - uint8_t, - uint16_t - >::type, - uint32_t - >::type; - - /// Element type - using Element = T; - - /// Number of logical elements per stored object - static constexpr int kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; - - /// Number of storage elements - static constexpr size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; - - /// Number of logical elements - static constexpr size_t kElements = N; - - /// Bitmask for covering one item - static constexpr Storage kMask = ((Storage(1) << sizeof_bits::value) - 1); - - // - // C++ standard members with pointer types removed - // - - typedef T value_type; - typedef size_t size_type; - typedef ptrdiff_t difference_type; - typedef value_type *pointer; - typedef value_type const *const_pointer; - - // - // References - // - - /// Reference object inserts or extracts sub-byte items - class reference { - /// Pointer to storage element - Storage *ptr_{nullptr}; - - /// Index into elements packed into Storage object - int idx_{0}; - - public: - - reference() = default; - - /// Ctor - CUTLASS_HOST_DEVICE - reference(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } - - /// Assignment - CUTLASS_HOST_DEVICE - reference &operator=(T x) { - // `*ptr_ & kUpdateMask` will read ptr_ before write to it - // This means code pattern like - // - // ```cpp - // Array result; - // result[0] = xxx; - // ``` - // - // Will leads to compiler warning on use of uninitialized member variable. Although we know - // this read of uninitialized member variable is harmeless. - -#if defined(__clang__) -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wuninitialized" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wuninitialized" -# pragma GCC diagnostic ignored "-Wmaybe-uninitialized" -#endif - - Storage item = (reinterpret_cast(x) & kMask); - - Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits::value))); - - *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits::value))); - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - - return *this; - } - - CUTLASS_HOST_DEVICE - T get() const { - Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & kMask); - return reinterpret_cast(item); - } - - /// Extract - CUTLASS_HOST_DEVICE - operator T() const { - return get(); - } - - /// Explicit cast to int - CUTLASS_HOST_DEVICE - explicit operator int() const { - return int(get()); - } - - /// Explicit cast to float - CUTLASS_HOST_DEVICE - explicit operator float() const { - return float(get()); - } - }; - - /// Reference object extracts sub-byte items - class const_reference { - - /// Pointer to storage element - Storage const *ptr_{nullptr}; - - /// Index into elements packed into Storage object - int idx_{0}; - - public: - - const_reference() = default; - - /// Ctor - CUTLASS_HOST_DEVICE - const_reference(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } - - CUTLASS_HOST_DEVICE - const T get() const { - Storage item = (*ptr_ >> (idx_ * sizeof_bits::value)) & kMask; - return reinterpret_cast(item); - } - - /// Extract - CUTLASS_HOST_DEVICE - operator T() const { - Storage item = Storage(Storage(*ptr_ >> Storage(idx_ * sizeof_bits::value)) & kMask); - return reinterpret_cast(item); - } - - /// Explicit cast to int - CUTLASS_HOST_DEVICE - explicit operator int() const { - return int(get()); - } - - /// Explicit cast to float - CUTLASS_HOST_DEVICE - explicit operator float() const { - return float(get()); - } - }; - - // - // Iterators - // - - /// Bidirectional iterator over elements - class iterator { - - /// Pointer to storage element - Storage *ptr_{nullptr}; - - /// Index into elements packed into Storage object - int idx_{0}; - - public: - - iterator() = default; - - CUTLASS_HOST_DEVICE - iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } - - CUTLASS_HOST_DEVICE - iterator &operator++() { - ++idx_; - if (idx_ == kElementsPerStoredItem) { - ++ptr_; - idx_ = 0; - } - return *this; - } - - CUTLASS_HOST_DEVICE - iterator &operator--() { - if (!idx_) { - --ptr_; - idx_ = kElementsPerStoredItem - 1; - } - else { - --idx_; - } - return *this; - } - - CUTLASS_HOST_DEVICE - iterator operator++(int) { - iterator ret(*this); - ++idx_; - if (idx_ == kElementsPerStoredItem) { - ++ptr_; - idx_ = 0; - } - return ret; - } - - CUTLASS_HOST_DEVICE - iterator operator--(int) { - iterator ret(*this); - if (!idx_) { - --ptr_; - idx_ = kElementsPerStoredItem - 1; - } - else { - --idx_; - } - return ret; - } - - CUTLASS_HOST_DEVICE - reference operator*() const { - return reference(ptr_, idx_); - } - - CUTLASS_HOST_DEVICE - bool operator==(iterator const &other) const { - return ptr_ == other.ptr_ && idx_ == other.idx_; - } - - CUTLASS_HOST_DEVICE - bool operator!=(iterator const &other) const { - return !(*this == other); - } - }; - - /// Bidirectional constant iterator over elements - class const_iterator { - - /// Pointer to storage element - Storage const *ptr_{nullptr}; - - /// Index into elements packed into Storage object - int idx_{0}; - - public: - - const_iterator() = default; - - CUTLASS_HOST_DEVICE - const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } - - CUTLASS_HOST_DEVICE - iterator &operator++() { - ++idx_; - if (idx_ == kElementsPerStoredItem) { - ++ptr_; - idx_ = 0; - } - return *this; - } - - CUTLASS_HOST_DEVICE - iterator &operator--() { - if (!idx_) { - --ptr_; - idx_ = kElementsPerStoredItem - 1; - } - else { - --idx_; - } - return *this; - } - - CUTLASS_HOST_DEVICE - iterator operator++(int) { - iterator ret(*this); - ++idx_; - if (idx_ == kElementsPerStoredItem) { - ++ptr_; - idx_ = 0; - } - return ret; - } - - CUTLASS_HOST_DEVICE - iterator operator--(int) { - iterator ret(*this); - if (!idx_) { - --ptr_; - idx_ = kElementsPerStoredItem - 1; - } - else { - --idx_; - } - return ret; - } - - CUTLASS_HOST_DEVICE - const_reference operator*() const { - return const_reference(ptr_, idx_); - } - - CUTLASS_HOST_DEVICE - bool operator==(iterator const &other) const { - return ptr_ == other.ptr_ && idx_ == other.idx_; - } - - CUTLASS_HOST_DEVICE - bool operator!=(iterator const &other) const { - return !(*this == other); - } - }; - - /// Bidirectional iterator over elements - class reverse_iterator { - - /// Pointer to storage element - Storage *ptr_{nullptr}; - - /// Index into elements packed into Storage object - int idx_{0}; - - public: - - reverse_iterator() = default; - - CUTLASS_HOST_DEVICE - reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } - }; - - /// Bidirectional constant iterator over elements - class const_reverse_iterator { - - /// Pointer to storage element - Storage const *ptr_{nullptr}; - - /// Index into elements packed into Storage object - int idx_{0}; - - public: - - const_reverse_iterator() = default; - - CUTLASS_HOST_DEVICE - const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } - }; - - /// Efficient clear method - CUTLASS_HOST_DEVICE - void clear() { - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < int(kStorageElements); ++i) { - storage[i] = Storage(0); - } - } - - CUTLASS_HOST_DEVICE - reference at(size_type pos) { - return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); - } - - CUTLASS_HOST_DEVICE - const_reference at(size_type pos) const { - return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); - } - - CUTLASS_HOST_DEVICE - reference operator[](size_type pos) { - return at(pos); - } - - CUTLASS_HOST_DEVICE - const_reference operator[](size_type pos) const { - return at(pos); - } - - CUTLASS_HOST_DEVICE - reference front() { - return at(0); - } - - CUTLASS_HOST_DEVICE - const_reference front() const { - return at(0); - } - - CUTLASS_HOST_DEVICE - reference back() { - return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); - } - - CUTLASS_HOST_DEVICE - const_reference back() const { - return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); - } - - CUTLASS_HOST_DEVICE - pointer data() { - return reinterpret_cast(storage); - } - - CUTLASS_HOST_DEVICE - const_pointer data() const { - return reinterpret_cast(storage); - } - - CUTLASS_HOST_DEVICE - Storage * raw_data() { - return storage; - } - - CUTLASS_HOST_DEVICE - Storage const * raw_data() const { - return storage; - } - - CUTLASS_HOST_DEVICE - constexpr bool empty() const { - return !kElements; - } - - CUTLASS_HOST_DEVICE - constexpr size_type size() const { - return kElements; - } - - CUTLASS_HOST_DEVICE - constexpr size_type max_size() const { - return kElements; - } - - CUTLASS_HOST_DEVICE - void fill(T const &value) { - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElementsPerStoredItem; ++i) { - reference ref(storage, i); - ref = value; - } - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < kStorageElements; ++i) { - storage[i] = storage[0]; - } - } - - CUTLASS_HOST_DEVICE - iterator begin() { - return iterator(storage); - } - - CUTLASS_HOST_DEVICE - const_iterator cbegin() const { - return const_iterator(storage); - } - - CUTLASS_HOST_DEVICE - iterator end() { - return iterator(storage + kStorageElements); - } - - CUTLASS_HOST_DEVICE - const_iterator cend() const { - return const_iterator(storage + kStorageElements); - } - - CUTLASS_HOST_DEVICE - reverse_iterator rbegin() { - return reverse_iterator(storage + kStorageElements); - } - - CUTLASS_HOST_DEVICE - const_reverse_iterator crbegin() const { - return const_reverse_iterator(storage + kStorageElements); - } - - CUTLASS_HOST_DEVICE - reverse_iterator rend() { - return reverse_iterator(storage); - } - - CUTLASS_HOST_DEVICE - const_reverse_iterator crend() const { - return const_reverse_iterator(storage); - } - -private: - /// Internal storage - Storage storage[kStorageElements]; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/barrier.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/barrier.h deleted file mode 100644 index 8919e992af20ac2d7f2b5daa8a0cbd7a6f7b79e5..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/barrier.h +++ /dev/null @@ -1,377 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implementation of a CTA-wide barrier for inter-CTA synchronization. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/arch/barrier.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -namespace detail { - -// -// Utilities for abstracting synchronization methods for barriers -// - -struct SyncthreadsSync { - CUTLASS_DEVICE - static void sync() { - __syncthreads(); - } -}; - -struct SyncwarpSync { - CUTLASS_DEVICE - static void sync() { - __syncwarp(); - } -}; - -template < - int ThreadCount, - int BarrierId -> -struct NamedBarrierSync { - CUTLASS_DEVICE - static void sync() { - cutlass::arch::NamedBarrier::sync(ThreadCount, static_cast(BarrierId)); - } -}; - -} // namepspace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Group or CTA-wide semaphore for inter-CTA synchronization. -template -struct GenericBarrier { - -public: - - /// Flag type - using T = int; - - /// Initial flag value - static const T INIT = 0; - - -protected: - - /// Load flag, as a strong acquire operation (int specialization) - CUTLASS_DEVICE - static int ld_acquire(int *ptr) - { - int state = 0; - -#if (__CUDA_ARCH__ >= 700) - /// SM70 and newer use memory consistency qualifiers - - // Acquire pattern using acquire modifier - asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); - -#else - asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); -#endif // (__CUDA_ARCH__ >= 700) - - return state; - } - - - /// Reduce into flag, with release pattern (int specialization) - CUTLASS_DEVICE - static void red_release(int *ptr, int val) - { -#if (__CUDA_ARCH__ >= 700) - /// SM70 and newer use memory consistency qualifiers - - // Release pattern using acq_rel fence + relaxed modifier. (The fence also releases data - // that was weakly-written by other threads prior to the last syncthreads) - asm volatile ("fence.acq_rel.gpu;\n"); - asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); - -#else - __threadfence(); - atomicAdd(ptr, val); -#endif // (__CUDA_ARCH__ >= 700) - } - - -public: - - /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter - CUTLASS_DEVICE - static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count) - { - T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; - - if (thread_idx == 0) - { - // Spin-loop - #pragma unroll 1 - while(ld_acquire(flag_ptr) < count) {} - } - - Sync::sync(); - } - - /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter - CUTLASS_DEVICE - static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) - { - T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; - - if (thread_idx == 0) - { - // Spin-loop - #pragma unroll 1 - while(ld_acquire(flag_ptr) != val) {} - } - Sync::sync(); - } - - /// Uses thread[0] to wait for the specified count of signals on the given flag counter - CUTLASS_DEVICE - static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { - T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; - - if (thread_idx == 0) - { - // Spin-loop - #pragma unroll 1 - while(atomicCAS(flag_ptr, val, 0) != val) {} - } - - Sync::sync(); - } - - /// Increment the arrival count for a flag - CUTLASS_DEVICE - static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx, int val = 1) - { - T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; - - Sync::sync(); - - if (thread_idx == 0) - { - red_release(flag_ptr, val); - } - } - - - /// Increment the arrival counts for a range of flags - CUTLASS_DEVICE - static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) - { - int flag_idx = first_flag_idx + thread_idx; - T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; - - // Barrier to make sure all other threads in group have written their data - Sync::sync(); - - // Select threads increment their flags - if (thread_idx < count) { - red_release(flag_ptr, val); - } - } -}; - -using Barrier = GenericBarrier; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/** Structure for managing multiple NamedBarriers to be used by different warp groups, allowing - * runtime index values to be used to call into named barriers with compile-time-constant IDs. - * - * @param ThreadCount_ Number of threads that will wait on a NamedBarrier with a given ID - * @param Offset Value added to the ID passed in by the user to determine the NamedBarrier ID to call into - * @param MaxNumNamedBarriers The maximum number of unique barrier IDs that will be requested on this type -**/ -template < - uint32_t ThreadCount_, - uint32_t Offset = 0, - uint32_t MaxNumNamedBarriers = 16 -> -struct NamedBarrierManager { - - static_assert(MaxNumNamedBarriers <= arch::NamedBarrier::HardwareMaxNumNamedBarriers); - static_assert(MaxNumNamedBarriers + Offset <= arch::NamedBarrier::HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15"); - - // Number of threads participating in the barrier - static constexpr uint32_t ThreadCount = ThreadCount_; - - template - using BarrierSync = cutlass::GenericBarrier>; - - // Underlying type used by all barriers for synchronization. Does not depend on - // template parameter BarrierId, so passing in 0 suffices. - using T = typename BarrierSync<0>::T; - - using IntegerSequence = cute::make_integer_sequence; - - CUTLASS_DEVICE - static - void wait_lt(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count) { - wait_lt_helper(idx, lock_ptr, thread_idx, flag_idx, count, IntegerSequence{}); - } - - CUTLASS_DEVICE - static void - wait_eq(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { - wait_eq_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); - } - - CUTLASS_DEVICE - static void - wait_eq_reset(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { - wait_eq_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); - } - - CUTLASS_DEVICE - static void - arrive_inc(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) { - arrive_inc_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); - } - - CUTLASS_DEVICE - static void - arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) { - arrive_range_inc_helper(idx, lock_ptr, thread_idx, first_flag_idx, count, val, IntegerSequence{}); - } - -private: - CUTLASS_DEVICE - static void - check_barrier_in_range([[maybe_unused]] uint32_t idx) { - assert((idx < MaxNumNamedBarriers) && "Index exceeds barrier count"); - } - - template - CUTLASS_DEVICE - static void - wait_lt_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count, cute::integer_sequence) { - check_barrier_in_range(idx); - ((Idx == idx && (BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count), true)) || ...); - } - - template - CUTLASS_DEVICE - static void - wait_eq_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val, cute::integer_sequence) { - check_barrier_in_range(idx); - if constexpr (Reset) { - ((Idx == idx && (BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val), true)) || ...); - } - else { - ((Idx == idx && (BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val), true)) || ...); - } - } - - template - CUTLASS_DEVICE - static void - arrive_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val, cute::integer_sequence) { - check_barrier_in_range(idx); - ((Idx == idx && (BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val), true)) || ...); - } - - template - CUTLASS_DEVICE - static void - arrive_range_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count, int val, cute::integer_sequence) { - check_barrier_in_range(idx); - ((Idx == idx && (BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val), true)) || ...); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/** Structure for synchronizing via contiguous barriers (e.g., __syncwarp, __syncthreads) - * via an API that mirrors that of NamedBarrierManager - * - * @param Synchronizer Synchronization helper exposing a `sync()` method to perform synchronization -**/ -template < - class Synchronizer, - uint32_t ThreadCount_ -> -struct SyncManager { - - // Number of threads participating in the barrier - static constexpr uint32_t ThreadCount = ThreadCount_; - - using BarrierSync = cutlass::GenericBarrier; - - // Underlying type used by all barriers for synchronization. - using T = typename BarrierSync::T; - - CUTLASS_DEVICE - static - void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) { - BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count); - } - - CUTLASS_DEVICE - static void - wait_eq(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { - BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val); - } - - CUTLASS_DEVICE - static void - wait_eq_reset(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { - BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val); - } - - CUTLASS_DEVICE - static void - arrive_inc(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) { - BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val); - } - - CUTLASS_DEVICE - static void - arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) { - BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/bfloat16.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/bfloat16.h deleted file mode 100644 index 5e2f40b1c85e24eb2bdeedb191529d53539f050c..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/bfloat16.h +++ /dev/null @@ -1,679 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Defines a proxy class for storing non-standard 16-bit floating point values with - 8 bits of exponent and 7 bit of mantissa. -*/ - -#pragma once - -#if defined(__CUDACC_RTC__) -#include "cutlass/floating_point_nvrtc.h" -#else -#include -#include -#include -#include -#endif - -#include -#include "cutlass/cutlass.h" -#include "cutlass/platform/platform.h" - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Floating-point type with 8 bits of exponent and 7 bits of mantissa. -struct alignas(2) bfloat16_t { - - // - // Data members - // - - /// Storage type - uint16_t storage; - - // - // Methods - // - - /// Constructs from an unsigned short - CUTLASS_HOST_DEVICE - static bfloat16_t bitcast(uint16_t x) { - bfloat16_t h; - h.storage = x; - return h; - } - -private: - struct from_32_bit_integer_t {}; - static constexpr from_32_bit_integer_t from_32_bit_integer{}; - - template - CUTLASS_HOST_DEVICE - explicit bfloat16_t(from_32_bit_integer_t, T x) { - static_assert(cutlass::platform::is_integral::value && sizeof(T) == 4, "Requires 32-bit integer"); - - float flt = static_cast(x); - uint32_t bits; - - #if defined(__CUDA_ARCH__) - bits = reinterpret_cast(flt); - #else - std::memcpy(&bits, &flt, sizeof(bits)); - #endif - - storage = uint16_t(bits >> 16); - } - -public: - /// Default constructor - bfloat16_t() = default; - - /// Reinterpret cast from CUDA's __nv_bfloat16 type - CUTLASS_HOST_DEVICE - explicit bfloat16_t(__nv_bfloat16 const & x) { - #if defined(__CUDA_ARCH__) - storage = reinterpret_cast(x); - #else - __nv_bfloat16_raw raw(x); - std::memcpy(&storage, &raw.x, sizeof(storage)); - #endif - } - - /// Floating-point conversion - round toward nearest - CUTLASS_HOST_DEVICE - explicit bfloat16_t(float x) { - - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) - - asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x)); - - #else - uint32_t bits; - - #if defined(__CUDA_ARCH__) - bits = reinterpret_cast(x); - #else - std::memcpy(&bits, &x, sizeof(bits)); - #endif - - if ((bits & 0x7f800000) != 0x7f800000) { - - bool mantissa_bit = ((bits & (1 << 16)) != 0); - bool round_bit = ((bits & (1 << 15)) != 0); - bool sticky_bit = ((bits & ((1 << 15) - 1)) != 0); - - if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) { - bits += uint32_t(1 << 16); - } - } - else if (bits & ~0xff800000) { - bits = 0x7fffffff; - } - - storage = uint16_t((bits >> 16) & 0xffff); - #endif - } - - /// Floating-point conversion - round toward nearest - CUTLASS_HOST_DEVICE - explicit bfloat16_t(double x): bfloat16_t(float(x)) { - - } - - /// Integer conversion - round toward nearest - CUTLASS_HOST_DEVICE - explicit bfloat16_t(int x) : bfloat16_t(from_32_bit_integer, x) {} - - CUTLASS_HOST_DEVICE - explicit bfloat16_t(uint32_t x) : bfloat16_t(from_32_bit_integer, x) {} - - /// Converts to float - CUTLASS_HOST_DEVICE - operator float() const { - unsigned bits = (unsigned(storage) << 16); - #if defined(__CUDA_ARCH__) - return reinterpret_cast(bits); - #else - float flt; - std::memcpy(&flt, &bits, sizeof(flt)); - return flt; - #endif - } - - /// Converts to float - CUTLASS_HOST_DEVICE - explicit operator double() const { - return double(float(*this)); - } - - /// Converts to int - CUTLASS_HOST_DEVICE - explicit operator int() const { - return int(float(*this)); - } - - /// Casts to bool - CUTLASS_HOST_DEVICE - explicit operator bool() const { - return (float(*this) != 0.0f); - } - - /// Bitcasts to CUDA's bf16 type - CUTLASS_DEVICE - __nv_bfloat16 to_nv_bfloat16() const { - return reinterpret_cast<__nv_bfloat16 const &>(storage); - } - - /// Obtains raw bits - CUTLASS_HOST_DEVICE - uint16_t raw() const { - return storage; - } - /// Returns the sign bit - CUTLASS_HOST_DEVICE - bool signbit() const { - return ((raw() & 0x8000) != 0); - } - - /// Returns the biased exponent - CUTLASS_HOST_DEVICE - int exponent_biased() const { - return int((raw() >> 7) & 0x0ff); - } - - /// Returns the unbiased exponent - CUTLASS_HOST_DEVICE - int exponent() const { - return exponent_biased() - 127; - } - - /// Returns the mantissa - CUTLASS_HOST_DEVICE - int mantissa() const { - return int(raw() & 0x7f); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -CUTLASS_HOST_DEVICE -bool signbit(cutlass::bfloat16_t const& h) { - return h.signbit(); -} - -CUTLASS_HOST_DEVICE -cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) { - return cutlass::bfloat16_t::bitcast(h.raw() & 0x7fff); -} - -CUTLASS_HOST_DEVICE -bool isnan(cutlass::bfloat16_t const& h) { - return (h.exponent_biased() == 0x0ff) && h.mantissa(); -} - -CUTLASS_HOST_DEVICE -bool isfinite(cutlass::bfloat16_t const& h) { - return (h.exponent_biased() != 0x0ff); -} - -CUTLASS_HOST_DEVICE -cutlass::bfloat16_t nan_bf16(const char*) { - // NVIDIA canonical NaN - return cutlass::bfloat16_t::bitcast(0x7fff); -} - -CUTLASS_HOST_DEVICE -bool isinf(cutlass::bfloat16_t const& h) { - return (h.exponent_biased() == 0x0ff) && !h.mantissa(); -} - -CUTLASS_HOST_DEVICE -bool isnormal(cutlass::bfloat16_t const& h) { - return h.exponent_biased() && h.exponent_biased() != 0x0ff; -} - -CUTLASS_HOST_DEVICE -int fpclassify(cutlass::bfloat16_t const& h) { - int exp = h.exponent_biased(); - int mantissa = h.mantissa(); - if (exp == 0x0ff) { - if (mantissa) { - return FP_NAN; - } - else { - return FP_INFINITE; - } - } - else if (!exp) { - if (mantissa) { - return FP_SUBNORMAL; - } - else { - return FP_ZERO; - } - } - return FP_NORMAL; -} - -CUTLASS_HOST_DEVICE -cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) { -#if defined(__CUDACC_RTC__) - return cutlass::bfloat16_t(sqrtf(float(h))); -#else - return cutlass::bfloat16_t(std::sqrt(float(h))); -#endif -} - -CUTLASS_HOST_DEVICE -bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) { - - uint16_t a_bits; - uint16_t b_bits; - - #if defined(__CUDA_ARCH__) - a_bits = reinterpret_cast(a); - b_bits = reinterpret_cast(b); - #else - std::memcpy(&a_bits, &a, sizeof(a_bits)); - std::memcpy(&b_bits, &b, sizeof(b_bits)); - #endif - - uint16_t a_mag = (a_bits & 0x7fff); - uint16_t b_sign = (b_bits & 0x8000); - uint16_t result = (a_mag | b_sign); - - return bfloat16_t::bitcast(result); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Standard Library operations and definitions -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -#if !defined(__CUDACC_RTC__) -namespace std { - -/// Numeric limits -template <> -struct numeric_limits { - static bool const is_specialized = true; - static bool const is_signed = true; - static bool const is_integer = false; - static bool const is_exact = false; - static bool const has_infinity = true; - static bool const has_quiet_NaN = true; - static bool const has_signaling_NaN = false; - static std::float_denorm_style const has_denorm = std::denorm_present; - static bool const has_denorm_loss = true; - static std::float_round_style const round_style = std::round_to_nearest; - static bool const is_iec559 = false; - static bool const is_bounded = true; - static bool const is_modulo = false; - static int const digits = 7; - - /// Least positive value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); } - - /// Minimum finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); } - - /// Maximum finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } -}; - -} // namespace std -#endif - -namespace cutlass { -namespace platform { - -/// Forward Declaration -template -struct numeric_limits; - -/// Numeric limits -template <> -struct numeric_limits { - static bool const is_specialized = true; - static bool const is_signed = true; - static bool const is_integer = false; - static bool const is_exact = false; - static bool const has_infinity = true; - static bool const has_quiet_NaN = true; - static bool const has_signaling_NaN = false; -#if !defined(__CUDACC_RTC__) - static std::float_denorm_style const has_denorm = std::denorm_present; -#endif - static bool const has_denorm_loss = true; -#if !defined(__CUDACC_RTC__) - static std::float_round_style const round_style = std::round_to_nearest; -#endif - static bool const is_iec559 = false; - static bool const is_bounded = true; - static bool const is_modulo = false; - static int const digits = 7; - - /// Least positive value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); } - - /// Minimum finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); } - - /// Maximum finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } -}; - -} // namespace platform -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Arithmetic operators -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -CUTLASS_HOST_DEVICE -bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - return __heq(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); -#else - return float(lhs) == float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - return __hne(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); -#else - return float(lhs) != float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - return __hlt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); -#else - return float(lhs) < float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - return __hle(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); -#else - return float(lhs) <= float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - return __hgt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); -#else - return float(lhs) > float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - return __hge(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); -#else - return float(lhs) >= float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - return bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); -#else - return bfloat16_t(float(lhs) + float(rhs)); -#endif -} - -CUTLASS_HOST_DEVICE -bfloat16_t operator-(bfloat16_t const& lhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - return bfloat16_t(__hneg(lhs.to_nv_bfloat16())); -#else - return bfloat16_t(-float(lhs)); -#endif -} - -CUTLASS_HOST_DEVICE -bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - return bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); -#else - return bfloat16_t(float(lhs) - float(rhs)); -#endif -} - -CUTLASS_HOST_DEVICE -bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - return bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); -#else - return bfloat16_t(float(lhs) * float(rhs)); -#endif -} - -CUTLASS_HOST_DEVICE -bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - return bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); -#else - return bfloat16_t(float(lhs) / float(rhs)); -#endif -} - -CUTLASS_HOST_DEVICE -bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); -#else - lhs = bfloat16_t(float(lhs) + float(rhs)); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); -#else - lhs = bfloat16_t(float(lhs) - float(rhs)); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - lhs = bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); -#else - lhs = bfloat16_t(float(lhs) * float(rhs)); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - lhs = bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); -#else - lhs = bfloat16_t(float(lhs) / float(rhs)); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -bfloat16_t& operator++(bfloat16_t & lhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); -#else - float tmp(lhs); - ++tmp; - lhs = bfloat16_t(tmp); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -bfloat16_t& operator--(bfloat16_t & lhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); -#else - float tmp(lhs); - --tmp; - lhs = bfloat16_t(tmp); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -bfloat16_t operator++(bfloat16_t & lhs, int) { - bfloat16_t ret(lhs); -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); -#else - float tmp(lhs); - tmp++; - lhs = bfloat16_t(tmp); -#endif - return ret; -} - -CUTLASS_HOST_DEVICE -bfloat16_t operator--(bfloat16_t & lhs, int) { - bfloat16_t ret(lhs); -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); -#else - float tmp(lhs); - tmp--; - lhs = bfloat16_t(tmp); -#endif - return ret; -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// User-defined literals -// - -CUTLASS_HOST_DEVICE -cutlass::bfloat16_t operator "" _bf16(long double x) { - return cutlass::bfloat16_t(float(x)); -} - -CUTLASS_HOST_DEVICE -cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) { - return cutlass::bfloat16_t(int(x)); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3.h deleted file mode 100644 index 8788f18b99d5c9d700a0f6f28625097f41862c74..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3.h +++ /dev/null @@ -1,143 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Basic include for CUTLASS BLAS3/HPC code. - - -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/blas3_types.h" -#include "cutlass/coord.h" -#include "cutlass/complex.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines FillMode inversions -template -struct InvertFillMode; - -/// Invert FillMode lower to upper -template <> -struct InvertFillMode { - static FillMode const mode = FillMode::kUpper; -}; - -/// Invert FillMode upper to lower -template <> -struct InvertFillMode { - static FillMode const mode = FillMode::kLower; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines SideMode inversions -template -struct InvertSideMode; - -/// Invert SideMode left to right -template <> -struct InvertSideMode { - static SideMode const mode = SideMode::kRight; -}; - -/// Invert SideMode right to left -template <> -struct InvertSideMode { - static SideMode const mode = SideMode::kLeft; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines correct compare operation for Triangular matrix boundary -template -struct TrMatrixCompareOp { - using Index = int32_t; - using Type = typename platform::conditional< - (kFillMode == FillMode::kLower), - greater_equal, - less_equal>::type; -}; - -template -struct TrMatrixCompareOp { - using Index = int32_t; - using Type = typename platform::conditional< - (kFillMode == FillMode::kLower), - greater_equal, - less_equal>::type; -}; - -template -struct TrMatrixCompareOp { - using Index = int32_t; - using Type = typename platform::conditional< - (kFillMode == FillMode::kLower), - greater, - less>::type; -}; -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Returns precision in terms of bits (based on datatype) to fill tensors with. -// Defaults to 5 bits of mantissa for TF32 and FP32 (with implicit round-offs). -// Also defines acceptable mantissa result variance/error. -template -struct MantissaInBits { - static int constexpr bits = 5; - static double constexpr error = 1.0e-7; -}; - -// Full precision is supported for FP64 -template <> -struct MantissaInBits { - static int constexpr bits = 30; - static double constexpr error = 1.0e-15; -}; - -template <> -struct MantissaInBits> { - static int constexpr bits = 30; - static double constexpr error = 1.0e-14; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3_types.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3_types.h deleted file mode 100644 index e47002b1a7255478f3a8d08518a4e081cbfd2422..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3_types.h +++ /dev/null @@ -1,78 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Enumerated type describing the type of kernel (based on input or output matrices). -enum class BlasMode { - kGemm, - kSymmetric, - kHermitian, - kTriangular, - kInvalid -}; - -/// Enumerated type describing the fill mode for matrices for BLAS functions. -enum class FillMode { - kFull, /// The entire tensor is covered. - kLower, /// The 'lower' part of a tensor is covered including diagonal - kUpper, /// The 'upper' part of a tensor is covered including diaognal - kDiagonal, /// Only diagonal elements are covered. - kNone, /// No element is covered. - kInvalid -}; - -/// Enumerated type describing the diagonal property of matrices for BLAS functions. -enum class DiagType { - kNonUnit, - kUnit, - kZero, // Only used internally for computing SYMM/HEMM - kInvalid -}; - -/// Enumerated type describing the side dense matrix is in matrix equation for BLAS functions. -enum class SideMode { - kLeft, - kRight, - kInvalid -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/block_striped.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/block_striped.h deleted file mode 100644 index 93665c64047d847a6fc9de3f5ec691caa8186dbc..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/block_striped.h +++ /dev/null @@ -1,267 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable, - statically-sized array types to global memory. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/wmma_array.h" -#include "cutlass/functional.h" -#include "cutlass/complex.h" - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// -// AccessWidth -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes the maximal power-of-two that evenly divides the size of T, capped at Limit -template < - typename T, - int Limit> -struct AccessWidth -{ - // Inductive case - template < - int ObjectBytes, /// Size of T in bytes - int AlignBytes, /// Template induction variable - bool IsAligned = /// Whether ObjectBytes is an even multiple of AlignBytes - ((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))> - struct Detail - { - static const int value = Detail::value; - }; - - // Base case (ObjectBytes is not an even multiple of AlignBytes) - template < - int ObjectBytes, /// Size of T in bytes - int AlignBytes> /// Template induction variable - struct Detail - { - static const int value = AlignBytes / 2; - }; - - /// The maximal power-of-two that evenly divides the size of T - static const int value = Detail< - (int) sizeof(T), - 1>::value; -}; - - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// StripedAccessType -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// ReinterpretCast type for striping a trivially-copyable type in global memory -/// (Default specialization. Striping granularity is type T.) -template < - typename T, /// Data type - int TransferBytes = /// Data access width (16 byte max for global memory access on current architectures) - AccessWidth::value> -struct alignas(TransferBytes) StripedAccessType : public T -{}; - - -/// ReinterpretCast type for striping a trivially-copyable type in global memory -/// (Specialization for cutlass::Array. Striping granularity is a multiple of T.) -template < - typename T, /// Array element type - int N, /// Number of elements in array - bool RegisterSized, /// T is register-sized - int TransferBytes> /// Data access width -struct StripedAccessType< - Array, - TransferBytes> -: public AlignedArray< - T, // Element type of StripedAccessType - __NV_STD_MAX(1, TransferBytes / (int) sizeof(T)), // Number of elements T in StripedAccessType - TransferBytes> // Alignment of StripedAccessType -{}; - - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) - -/// ReinterpretCast type for striping a trivially-copyable type in global memory -/// (Specialization for cutlass::WmmaFragmentArray. Striping granularity is a multiple of T.) -template< - typename Use, - int m, - int n, - int k, - typename ElementT, - typename Layout, - int kFragments, - int TransferBytes> -struct StripedAccessType< - WmmaFragmentArray, kFragments>, - TransferBytes> -: public AlignedArray< - ElementT, - __NV_STD_MAX(1, TransferBytes / (int) sizeof(ElementT)), - TransferBytes> -{}; - -#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// BlockStriped -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Utility for performing block-striped access (load, store) of trivially-copyable, -/// statically-sized array types to global memory -template < - int BlockThreads, - typename ArrayT, - typename AccessT = StripedAccessType > -struct BlockStriped -{ - /// Number of striped accesses - static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT)); - static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type"); - - /// Load - CUTLASS_DEVICE - static void load(ArrayT &data, ArrayT *ptr, int thread_idx) - { - AccessT *access_input = reinterpret_cast(ptr); - AccessT *access_data = reinterpret_cast(&data); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kStripes; ++i) { - access_data[i] = access_input[(BlockThreads * i) + thread_idx]; - } - } - - /// Load & Add - CUTLASS_DEVICE - static void load_add(ArrayT &data, ArrayT *ptr, int thread_idx) - { - AccessT *access_input = reinterpret_cast(ptr); - AccessT *access_data = reinterpret_cast(&data); - - plus add; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kStripes; ++i) - { - access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]); - } - } - - /// Store - CUTLASS_DEVICE - static void store(ArrayT *ptr, const ArrayT &data, int thread_idx) - { - AccessT *access_output = reinterpret_cast(ptr); - const AccessT *access_data = reinterpret_cast(&data); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kStripes; ++i) { - access_output[(BlockThreads * i) + thread_idx] = access_data[i]; - } - } - -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// BlockStripedReduce -///////////////////////////////////////////////////////////////////////////////////////////////// - - -/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, -/// statically-sized array types to global memory. -/// (Default specialization) -template < - int BlockThreads, - typename ArrayT, - typename ElementT = typename StripedAccessType::Element> -struct BlockStripedReduce : - BlockStriped< - BlockThreads, - ArrayT, - ElementT> -{ - /// Reduce - CUTLASS_DEVICE - static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) - { - cutlass::atomic_add reduce; - ElementT *access_output = reinterpret_cast(ptr); - const ElementT *access_data = reinterpret_cast(&data); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < BlockStripedReduce::kStripes; ++i) { - reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); - } - } -}; - - -/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, -/// statically-sized array types to global memory. -/// (Specialization for half_t. Uses half2 vectorized-reduction.) -template < - int BlockThreads, - typename ArrayT> -struct BlockStripedReduce : - BlockStriped< - BlockThreads, - ArrayT, - half2> -{ - static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length"); - - /// Reduce - CUTLASS_DEVICE - static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) - { - cutlass::atomic_add reduce; - half2 *access_output = reinterpret_cast(ptr); - const half2 *access_data = reinterpret_cast(&data); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < BlockStripedReduce::kStripes; ++i) - { - reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); - } - } -}; - - -} // namespace cutlass - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/cluster_launch.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/cluster_launch.hpp deleted file mode 100644 index 22c17dba702f62eeab80ab5b3399bda269f4f4d2..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/cluster_launch.hpp +++ /dev/null @@ -1,394 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief CUDA interfaces to launch CUTLASS device-level operators (for >= SM90) that use thread-block clusters. -*/ - -#pragma once - -#include -#include "cutlass/cutlass.h" -#include "cutlass/trace.h" -#include -#include "cutlass/arch/synclog.hpp" - -#if defined(__CUDACC_RTC__) -#include CUDA_STD_HEADER(type_traits) -#else -#include -#include -#endif - -#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) -# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED -#endif - -#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) - # define CUDA_ENABLE_PREFERRED_CLUSTER -#endif -namespace cutlass { - -#ifndef NDEBUG -#define Return_Status(cudaError_t_status) \ - if (cudaError_t_status != cudaSuccess) { \ - fprintf(stderr, \ - "[ ERROR: CUDA Runtime ] %s:%d: %s\n", \ - __FILE__, \ - __LINE__, \ - cudaGetErrorString(cudaError_t_status)); \ - return Status::kInvalid; \ - } else { \ - return Status::kSuccess; \ - } -#else -#define Return_Status(cudaError_t_status) \ - if (cudaError_t_status != cudaSuccess) { \ - return Status::kInvalid; \ - } else { \ - return Status::kSuccess; \ - } -#endif - -struct ClusterLauncher { - constexpr static int MaxClusterSize = 32; - - struct LaunchConfig { -#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) - cudaLaunchConfig_t launch_config; - - #if defined(CUDA_ENABLE_PREFERRED_CLUSTER) - constexpr static int numAttrs = 3; - #else - - constexpr static int numAttrs = 2; - #endif - cudaLaunchAttribute launch_attribute[numAttrs]; - // Commonly used utility functions - dim3 gridDim() { return launch_config.gridDim; } - dim3 blockDim() { return launch_config.blockDim; } -#endif - }; - - // Check for hardware compatibility - static inline CUTLASS_HOST - Status check_cluster_dims(dim3 grid, dim3 cluster) { - if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) && - (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) { - return Status::kSuccess; - } - else { - CUTLASS_TRACE_HOST("ClusterLauncher: Invalid cluster configuration -- aborting launch."); - return Status::kInvalid; - } - } - - static inline CUTLASS_HOST - Status -#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) - init(void const* kernel_function) -#else - init(void const* /* kernel_function */) -#endif - { -#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - if (kernel_function == nullptr) { - CUTLASS_TRACE_HOST("kernel_function is null"); - return Status::kInvalid; - } - CUTLASS_TRACE_HOST("Checking previous error state before calling cudaFuncSetAttribute"); - cudaError_t prevStatus = cudaGetLastError(); - if (prevStatus != cudaSuccess) { - fprintf(stderr, - "[ ERROR: CUDA Runtime ] %s:%d: %s\n", - __FILE__, - __LINE__, - cudaGetErrorString(prevStatus)); - return Status::kInvalid; - } - CUTLASS_TRACE_HOST("Calling cudaFuncSetAttribute"); -#endif - // This attribute was added in CUDA 11.8. - cudaError_t status = - cudaFuncSetAttribute( - kernel_function, cudaFuncAttributeNonPortableClusterSizeAllowed, 1); - Return_Status(status); -#else - return Status::kInvalid; -#endif - } - - static inline CUTLASS_HOST - LaunchConfig make_cluster_launch_config( - dim3 const grid_dims, - dim3 const cluster_dims, - dim3 const block_dims, - size_t const smem_size = 0, - cudaStream_t cuda_stream = 0, - bool launch_with_pdl = false - , dim3 const fallback_cluster_dims = {0, 0, 0} - ) { - LaunchConfig cluster_launch_config; -#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) - auto &launch_config = cluster_launch_config.launch_config; - auto &launch_attribute = cluster_launch_config.launch_attribute; - auto numAttrs = cluster_launch_config.numAttrs; - - launch_attribute[0].id = cudaLaunchAttributeClusterDimension; - - bool have_fallback = fallback_cluster_dims.x * fallback_cluster_dims.y * fallback_cluster_dims.z > 0; - - if (have_fallback) { - launch_attribute[0].val.clusterDim = {fallback_cluster_dims.x, fallback_cluster_dims.y, fallback_cluster_dims.z}; - CUTLASS_TRACE_HOST("ClusterLauncher: Setting fallback ClusterDims = " - "(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n"); - } - else { - - launch_attribute[0].val.clusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z}; - CUTLASS_TRACE_HOST("ClusterLauncher: Setting ClusterDims = " - "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); - - } - -#if defined(CUDA_ENABLE_PREFERRED_CLUSTER) - if (have_fallback) { - if (cute::initialize_preferred_cluster_launch(nullptr, grid_dims, cluster_dims, fallback_cluster_dims)) { - launch_attribute[1].id = cudaLaunchAttributePreferredClusterDimension; - launch_attribute[1].val.preferredClusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z}; - CUTLASS_TRACE_HOST("ClusterLauncher: Setting preferred ClusterDims = " - "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); - } - } - else { - numAttrs--; - } -#endif - - - // PDL attributes - launch_attribute[numAttrs - 1].id = cudaLaunchAttributeProgrammaticStreamSerialization; - launch_attribute[numAttrs - 1].val.programmaticStreamSerializationAllowed = 1; - - launch_config.gridDim = {grid_dims.x, grid_dims.y, grid_dims.z}; - launch_config.blockDim = {block_dims.x, block_dims.y, block_dims.z}; - launch_config.dynamicSmemBytes = smem_size; - launch_config.stream = cuda_stream; - launch_config.numAttrs = launch_with_pdl ? numAttrs : numAttrs - 1; - launch_config.attrs = launch_attribute; - return cluster_launch_config; -#else - CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch."); - return cluster_launch_config; -#endif - } - - // This is the method we expect to use going forward - static inline CUTLASS_HOST - Status launch( - dim3 const grid_dims, - dim3 const cluster_dims, - dim3 const block_dims, - size_t const smem_size, - cudaStream_t cuda_stream, - void const* kernel, - void** kernel_params, - bool launch_with_pdl = false) { -#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) - LaunchConfig cluster_launch_config = make_cluster_launch_config(grid_dims, cluster_dims, - block_dims, smem_size, cuda_stream, launch_with_pdl); - - auto launch_grid_dims = cluster_launch_config.gridDim(); - if (check_cluster_dims(launch_grid_dims, cluster_dims) != Status::kSuccess) { - CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting."); - return Status::kInvalid; - } - - auto init_status = init(kernel); - if (init_status != Status::kSuccess) { - CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting."); - return Status::kInvalid; - } - - CUTLASS_TRACE_HOST("ClusterLauncher: Launching GridDims = " - "(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), " - "And ClusterDims = " - "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); - - cutlass::arch::synclog_setup(); - cudaError_t status = cudaLaunchKernelExC(&cluster_launch_config.launch_config, kernel, kernel_params); - Return_Status(status); -#else - CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch."); - return Status::kInvalid; -#endif - } - - - // This is the method we expect to use going forward - // Launch a preferred cluster grid - static inline CUTLASS_HOST - Status launch_with_fallback_cluster( - dim3 const grid_dims, - dim3 const preferred_cluster_dims, - dim3 const fallback_cluster_dims, - dim3 const block_dims, - size_t const smem_size, - cudaStream_t cuda_stream, - void const* kernel, - void** kernel_params, - bool launch_with_pdl = false) { -#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) - LaunchConfig cluster_launch_config = make_cluster_launch_config(grid_dims, preferred_cluster_dims, - block_dims, smem_size, cuda_stream, launch_with_pdl, fallback_cluster_dims); - - auto launch_grid_dims = cluster_launch_config.gridDim(); - if (check_cluster_dims(launch_grid_dims, preferred_cluster_dims) != Status::kSuccess) { - CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting."); - return Status::kInvalid; - } - - auto init_status = init(kernel); - if (init_status != Status::kSuccess) { - CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting."); - return Status::kInvalid; - } - - CUTLASS_TRACE_HOST("ClusterLauncher: Launching \n\tGridDims = " - "(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), " - "\n\tPreferred ClusterDims = " - "(" << preferred_cluster_dims.x << ", " << preferred_cluster_dims.y << ", " << preferred_cluster_dims.z << ")," - "\n\tFallback ClusterDims = " - "(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n"); - - cutlass::arch::synclog_setup(); - cudaError_t status = cudaLaunchKernelExC(&cluster_launch_config.launch_config, kernel, kernel_params); - Return_Status(status); -#else - CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch."); - return Status::kInvalid; -#endif - } - - -}; - -namespace detail { - -template -void* checked_addressof(Arg&& arg) { - static_assert(! std::is_rvalue_reference_v || ! std::is_const_v, "You cannot take the address of a const rvalue reference (const T&&)."); - // We use std::addressof to ensure we get the address, - // in case the type has an overloaded operator&. - // Note that this precludes `const T&&` references. - return const_cast(reinterpret_cast(std::addressof(arg))); -} - -} // namespace detail - -//! Parameters for launch_on_cluster (see below). -struct ClusterLaunchParams { - //! Grid dimensions - dim3 grid_dims{1, 1, 1}; - - //! Block dimensions - dim3 block_dims{1, 1, 1}; - - //! Cluster dimensions - dim3 cluster_dims{1, 1, 1}; - - //! Number of bytes required for the kernel's shared memory. - int smem_size_in_bytes = 0; - - //! CUDA stream on which to launch the kernel. - cudaStream_t cuda_stream = nullptr; -}; - -/// @brief Launch the kernel on the stream using cluster launch. -/// -/// @param params Cluster launch parameters (see above). -/// @param kernel_ptr Pointer to the kernel function (see example). -/// @param args Zero or more arguments to pass to the kernel. -/// -/// @tparam Args Types of the arguments passed to the kernel. -/// Don't specify this/these template argument(s) explicitly. -/// -/// @return Status::Success on success, else an error code. -/// -/// @code -/// template -/// __global__ void kernel(A a, B b, C c); -/// -/// X x = get_x(); -/// Y y = get_y(); -/// Z z = get_z(); -/// -/// void const* kernel_ptr = -/// const_cast(reinterpret_cast( -/// &kernel)); -/// auto status = launch_kernel_on_cluster( -/// {grid_dims, block_dims, cluster_dims, sizeof(SharedMemory)}, -/// kernel_ptr, x, y, z); -/// @endcode -template -CUTLASS_HOST cutlass::Status -launch_kernel_on_cluster(const ClusterLaunchParams& params, - void const* kernel_ptr, - Args&& ... args) -{ - // Unfortunately, we find ourselves needing to pass in - // the parameters as an array of raw pointers. - if constexpr (sizeof...(Args) == 0) { - return cutlass::ClusterLauncher::launch( - params.grid_dims, - params.cluster_dims, - params.block_dims, - params.smem_size_in_bytes, - params.cuda_stream, - kernel_ptr, nullptr); - } - else { - void* kernel_params[sizeof...(Args)] = { - detail::checked_addressof(std::forward(args))... - }; - return cutlass::ClusterLauncher::launch( - params.grid_dims, - params.cluster_dims, - params.block_dims, - params.smem_size_in_bytes, - params.cuda_stream, - kernel_ptr, - kernel_params); - } -} - -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/complex.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/complex.h deleted file mode 100644 index 0287850bc6febe16a90695c82fabee566cdf9a82..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/complex.h +++ /dev/null @@ -1,821 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include - -#include -#include "cutlass/cutlass.h" -#if defined(__CUDACC_RTC__) -#include CUDA_STD_HEADER(cstdint) -#else -#include -#endif -#include "cutlass/functional.h" -#include "cutlass/platform/platform.h" -#include "cutlass/real.h" - -#include "cutlass/numeric_types.h" - -#include "cutlass/fast_math.h" - -#if !defined(__CUDACC_RTC__) -#include -#endif - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Enumeraed type describing a transformation on a complex value. -enum class ComplexTransform { - kNone, - kConjugate -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines ComplexTransform inversions -template -struct InvertComplexTransform; - -/// Invert ComplexTransform from kNone to kConjugate -template <> -struct InvertComplexTransform { - static ComplexTransform const transform = ComplexTransform::kConjugate; -}; - -/// Invert ComplexTransform from kConjugate to kNone -template <> -struct InvertComplexTransform { - static ComplexTransform const transform = ComplexTransform::kNone; -}; -///////////////////////////////////////////////////////////////////////////////////////////////// -////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Accessors for CUDA complex types -// - -#if !defined(__CUDACC_RTC__) -/// Returns the real part of the complex number -CUTLASS_HOST_DEVICE -float const &real(cuFloatComplex const &z) { return z.x; } - -/// Returns the real part of the complex number -CUTLASS_HOST_DEVICE -float &real(cuFloatComplex &z) { return z.x; } - -/// Returns the real part of the complex number -CUTLASS_HOST_DEVICE -double const &real(cuDoubleComplex const &z) { return z.x; } - -/// Returns the real part of the complex number -CUTLASS_HOST_DEVICE -double &real(cuDoubleComplex &z) { return z.x; } - -/// Returns the imaginary part of the complex number -CUTLASS_HOST_DEVICE -float const &imag(cuFloatComplex const &z) { return z.y; } - -/// Returns the imaginary part of the complex number -CUTLASS_HOST_DEVICE -float &imag(cuFloatComplex &z) { return z.y; } - -/// Returns the imaginary part of the complex number -CUTLASS_HOST_DEVICE -double const &imag(cuDoubleComplex const &z) { return z.y; } - -/// Returns the imaginary part of the complex number -CUTLASS_HOST_DEVICE -double &imag(cuDoubleComplex &z) { return z.y; } - -// Returns the conjugate of the complex number -CUTLASS_HOST_DEVICE cuFloatComplex -conj(cuFloatComplex const& z) { - return make_cuFloatComplex(z.x, -z.y); -} - -// Returns the conjugate of the complex number -CUTLASS_HOST_DEVICE cuDoubleComplex -conj(cuDoubleComplex const& z) { - return make_cuDoubleComplex(z.x, -z.y); -} -#endif - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Class for representing and manipulating complex numbers with conversions from built-in CUDA -/// complex types. - -template -class complex -{ - public: - /// Type alias for scalar type - using value_type = T; - - private: - // - // Data members - // - - /// Real part - T _real; - - /// Imaginary part - T _imag; - - public: - -// -// Methods -// - - /// Default constructor - complex() = default; - - /// Constructor - CUTLASS_HOST_DEVICE - complex(T r) : _real(r), _imag(T(0)) {} - - /// Constructor - CUTLASS_HOST_DEVICE - complex(T r, T i) : _real(r), _imag(i) {} - - /// Constructor - template - CUTLASS_HOST_DEVICE - complex(complex const &z) : _real(static_cast(z.real())), _imag(static_cast(z.imag())) {} - - - #if !defined(__CUDACC_RTC__) - /// Conversion from cuFloatComplex - CUTLASS_HOST_DEVICE - complex(cuFloatComplex const &z) : _real(static_cast(cuCrealf(z))), _imag(static_cast(cuCimagf(z))) {} - - /// Conversion from cuDoubleComplex - CUTLASS_HOST_DEVICE - complex(cuDoubleComplex const &z) : _real(static_cast(cuCreal(z))), _imag(static_cast(cuCimag(z))) {} - #endif - - /// Equality operator - CUTLASS_HOST_DEVICE bool operator==(complex const &rhs) const { - return this->real() == rhs.real() && this->imag() == rhs.imag(); - } - - /// Inequality operator - CUTLASS_HOST_DEVICE bool operator!=(complex const &rhs) const { - return !(*this == rhs); - } - - /// Addition - template - CUTLASS_HOST_DEVICE complex operator+(complex const &rhs) const { - return complex(this->real() + rhs.real(), this->imag() + rhs.imag()); - } - - /// Reduction into memory address. Components may update out of order. - template - CUTLASS_DEVICE void red(complex *ptr) const { - static_assert(platform::is_same::value, "Component type must match"); - cutlass::atomic_add reduce; - reduce(&ptr->_real, _real); - reduce(&ptr->_imag, _imag); - } - - /// Reduction into memory address. Components may update out of order. (Half specialization) - CUTLASS_DEVICE void red(complex *ptr) const { - static_assert(platform::is_same::value, "Component type must match"); - half2 *h2_ptr = reinterpret_cast(ptr); - half2 h2_data = reinterpret_cast(*this); - cutlass::atomic_add reduce; - reduce(h2_ptr, h2_data); - } - - /// Subtraction - template - CUTLASS_HOST_DEVICE complex operator-(complex const &rhs) const { - return complex(this->real() - rhs.real(), this->imag() - rhs.imag()); - } - - /// Multiplication - template - CUTLASS_HOST_DEVICE complex operator*(complex const &rhs) const { - return complex(this->real() * rhs.real() - this->imag() * rhs.imag(), - this->real() * rhs.imag() + this->imag() * rhs.real()); - } - - /// Scalar Multiplication - template - CUTLASS_HOST_DEVICE complex operator*(A const &s) const { - return complex(this->real() * s, this->imag() * s); - } - - /// Division - template - CUTLASS_HOST_DEVICE complex operator/(complex const &rhs) const { - T d = T(rhs.real() * rhs.real() + rhs.imag() * rhs.imag()); - - return complex( - (real() * rhs.real() + imag() * rhs.imag()) / d, - (imag() * rhs.real() - real() * rhs.imag()) / d - ); - } - - /// Scalar Division - template - CUTLASS_HOST_DEVICE complex operator/(A const &s) const { - return complex(this->real() / s, this->imag() / s); - } - - /// Addition - template - CUTLASS_HOST_DEVICE complex &operator+=(complex const &rhs) { - *this = *this + rhs; - return *this; - } - - /// Subtraction - template - CUTLASS_HOST_DEVICE complex &operator-=(complex const &rhs) { - *this = *this - rhs; - return *this; - } - - /// Multiplication - template - CUTLASS_HOST_DEVICE complex &operator*=(complex const &rhs) { - *this = *this * rhs; - return *this; - } - - /// Scalar multiplication - template - CUTLASS_HOST_DEVICE complex &operator*=(A s) { - *this = *this * s; - return *this; - } - - /// Division - template - CUTLASS_HOST_DEVICE complex &operator/=(complex const &rhs) { - *this = *this / rhs; - return *this; - } - - /// Accesses the real part of the complex number - CUTLASS_HOST_DEVICE - T const &real() const { return _real; } - - /// Accesses the real part of the complex number - CUTLASS_HOST_DEVICE - T &real() { return _real; } - - /// Accesses the imaginary part of the complex number - CUTLASS_HOST_DEVICE - T const &imag() const { return _imag; } - - /// Accesses the imaginary part of the complex number - CUTLASS_HOST_DEVICE - T &imag() { return _imag; } - - /// Set the real part of the complex number - CUTLASS_HOST_DEVICE - void real(T real) { _real = real; } - - /// Set the imaginary part of the complex number - CUTLASS_HOST_DEVICE - void imag(T imag) { _imag = imag; } - - #if !defined(__CUDACC_RTC__) - /// Converts to cuFloatComplex - CUTLASS_HOST_DEVICE - explicit operator cuFloatComplex() const { return make_cuFloatComplex(float(real()), float(imag())); } - - /// Converts to cuDoubleComplex - CUTLASS_HOST_DEVICE - explicit operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); } - #endif -}; - -// Complex conjugate -template -CUTLASS_HOST_DEVICE complex conj(complex const& z) { - return {z.real(), -z.imag()}; -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Accessors for complex template -// - -// Nonmember real and imag need to work for non-complex numbers too. -// That means cutlass::complex, std::complex, cuda::std::complex, and -// any user-defined complex number type that looks like std::complex. -// It's reasonable to assume that a "complex number type" has -// zero-argument real() and imag() member functions returning -// non-void. While cuFloatComplex and cuDoubleComplex lack those -// member functions, one-argument nonmember real and imag overloads -// for those types are defined above. - -namespace detail { - -template -struct has_zero_argument_real_member_function : - cutlass::platform::false_type -{}; - -template -struct has_zero_argument_real_member_function().real()) - > - > -> : cutlass::platform::true_type -{}; - -template -constexpr bool has_zero_argument_real_member_function_v = - has_zero_argument_real_member_function::value; - -template -struct has_zero_argument_imag_member_function : - cutlass::platform::false_type -{}; - -template -struct has_zero_argument_imag_member_function().imag()) - > - > -> : cutlass::platform::true_type -{}; - -template -constexpr bool has_zero_argument_imag_member_function_v = - has_zero_argument_imag_member_function::value; - -} // namespace detail - -template -CUTLASS_HOST_DEVICE auto real(T z) { - if constexpr (detail::has_zero_argument_real_member_function_v) { - return z.real(); - } else { - return z; - } -} - -template -CUTLASS_HOST_DEVICE auto imag(T z) { - if constexpr (detail::has_zero_argument_imag_member_function_v) { - return z.imag(); - } else { - // Imaginary part of a non-complex input has the same type as the - // input, and its value is zero. CUTLASS assumes in this case - // that value-initializing T is well-formed and results in zero. - return T{}; - } -} - -// -// Output operators -// - -#if !defined(__CUDACC_RTC__) -template -std::ostream &operator<<(std::ostream &out, complex const &z) { - T _r = real(z); - T _i = imag(z); - - if (bool(_i)) { - return out << _r << "+i" << _i; - } - return out << _r; -} -#endif - -// -// Non-member operators defined for complex types -// - - -// -// Non-member functions defined for complex numbers -// - -// abs returns the magnitude of the complex number. - -CUTLASS_HOST_DEVICE float abs(complex const &z) { - return ::hypot(z.real(), z.imag()); -} - -CUTLASS_HOST_DEVICE double abs(complex const &z) { - return ::hypot(z.real(), z.imag()); -} - -// In theory, it would make sense to add a complex -// specialization of abs here, since hypot works for long double too. -// In practice, long double doesn't have a portable number of bits or -// behavior, so users who care about higher-precision floating-point -// computation should probably insist on an actual FP128 type. - -template -CUTLASS_HOST_DEVICE T abs(complex const &z) { - // cutlass::complex permits all kinds of T, including types that - // don't have NaN. For a generic floating-point type with Inf - // and/or NaN, LAPACK's DLAPY2 algorithm would make sense, as it - // would handle issues like avoiding unwarranted overflow if - // z.real() or z.imag() is slightly bigger than the square root of - // the max finite number. That could be a future improvement; for - // now, the code just uses the naive algorithm. - // - // Use the "swap two-step" idiom so that argument-dependent lookup - // can find any CUTLASS-specific overloads. - using cutlass::sqrt; - return sqrt(z.real() * z.real() + z.imag() * z.imag()); -} - -/// Returns the magnitude of the complex number -template -CUTLASS_HOST_DEVICE T arg(complex const &z) { - return atan2(imag(z), real(z)); -} - -/// Returns the squared magnitude of a real number -template -CUTLASS_HOST_DEVICE T norm(T const &z) { - return z * z; -} - -/// Returns the squared magnitude of a real number -template <> -CUTLASS_HOST_DEVICE int8_t norm(int8_t const &z) { - return static_cast(z * z); -} - -/// Returns the squared magnitude of a complex number -template -CUTLASS_HOST_DEVICE double norm(complex const &z) { - return real(z) * real(z) + imag(z) * imag(z); -} - -/// Norm-accumulate calculation -template -CUTLASS_HOST_DEVICE R norm_accumulate(T const &x, R const & accumulator) { - return accumulator + static_cast(x) * static_cast(x); -} - -/// Norm accumulate specialized for complex types -template -CUTLASS_HOST_DEVICE R norm_accumulate(complex const &z, R const &accumulator) { - return accumulator + static_cast(real(z)) * static_cast(real(z)) + - static_cast(imag(z)) * static_cast(imag(z)); -} - -namespace detail { - -template -CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::true_type) { - return conj(z); -} - -template -CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::false_type) { - return z; -} - -template -CUTLASS_HOST_DEVICE T conj_impl(T const& z) { - constexpr bool use_unqualified_conj = - ! cutlass::platform::is_arithmetic_v && - ! detail::has_cutlass_conj_v && - detail::has_unqualified_conj_v; - return conj_impl(z, cutlass::platform::bool_constant{}); -} - -} // namespace detail - -// Return the complex conjugate of the input. -// -// This MUST be a function and not a function object, because it may -// be common practice for downstream types to define specifically -// cutlass::conj overloads, instead of overloads in their namespace. -// -// As a result of this being a function and not a function object, -// CUTLASS code needs to declare "using cutlass::conj;" in scope and -// then call this function unqualified, just like std::swap. -// -// If an overload already exists for cutlass::conj(T), that overload -// will be called instead of this one. Otherwise: -// -// 1. for arithmetic types, return z; -// -// 2. for types where (namespace-unqualified) conj(z) is well formed -// and cutlass::conj(z) is NOT well formed, return conj(z); and, -// -// 3. for everything else, return z. -// -// Regarding (1), the C++ Standard Library makes std::conj always -// return std::complex, even for (noncomplex) arithmetic types. -// cutlass::conj(T t) needs to return type T. This follows the -// convention of linear algebra software like the BLAS, where -// "conjugate transpose" means the same thing as "transpose" for a -// matrix of noncomplex numbers. -// -// Case (2) covers std::complex, cuda::std::complex, and non-Standard -// (including user-defined) complex number types (for which "conj(z)" -// is findable via argument-dependent lookup, but does not live in the -// cutlass namespace). It excludes cutlass::conj(z) in order to -// prevent infinite recursion. -// -// Case (3) covers non-Standard non-complex number types. -template -CUTLASS_HOST_DEVICE T conj(T const& z) { - return detail::conj_impl(z); -} - -/// Projects the complex number z onto the Riemann sphere -template -CUTLASS_HOST_DEVICE complex proj(complex const &z) { - T d = real(z) * real(z) + imag(z) * imag(z) + T(1); - return complex((T(2) * real(z)) / d, (T(2) * imag(z)) / d); -} - -/// Returns a complex number with magnitude r and phase theta -template -CUTLASS_HOST_DEVICE complex polar(T const &r, T const &theta = T()) { - return complex(r * cos(theta), r * sin(theta)); -} - -/// Computes the complex exponential of z. -template -CUTLASS_HOST_DEVICE complex exp(complex const &z) { - return complex(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z))); -} - -/// Computes the log of z -template -CUTLASS_HOST_DEVICE complex log(complex const &z) { - return complex(log(abs(z)), arg(z)); -} - -/// Computes the log base 10 of z -template -CUTLASS_HOST_DEVICE complex log10(complex const &z) { - return log(z) / T(log(T(10))); -} - -/// Computes the square root of complex number z -template -CUTLASS_HOST_DEVICE complex sqrt(complex const &z) { - return sqrt(T(2)) / T(2) * - complex(sqrt(sqrt(norm(z)) + real(z)), - (imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z))); -} - -/// Computes the cosine of complex z. -template -CUTLASS_HOST_DEVICE complex cos(complex const &z) { - return (exp(z) + exp(-z)) / T(2); -} - -/// Computes the sin of complex z. -template -CUTLASS_HOST_DEVICE complex sin(complex const &z) { - return (exp(-z) - exp(z)) * complex(T(0), T(1) / T(2)); -} - -/// Comparison -template -CUTLASS_HOST_DEVICE bool operator<(complex const &lhs, complex const &rhs) { - return true; -} - -////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex-valued type. -template -struct RealType< complex > -{ - using Type = T; - - /// Number of elements - static int const kExtent = 2; - - CUTLASS_HOST_DEVICE - static complex from_real(double x) { - return complex(static_cast(x)); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -CUTLASS_HOST_DEVICE -cutlass::complex from_real >(double r) { - return cutlass::complex(half_t(r)); -} - -template <> -CUTLASS_HOST_DEVICE -cutlass::complex from_real >(double r) { - return cutlass::complex(float(r)); -} - -template <> -CUTLASS_HOST_DEVICE -cutlass::complex from_real >(double r) { - return cutlass::complex(r); -} - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct is_complex { - static bool const value = false; -}; - -template -struct is_complex> { - static bool const value = true; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// functional.h numeric specializations -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Squares with optional conversion -template -struct magnitude_squared, Output> { - CUTLASS_HOST_DEVICE - Output operator()(complex lhs) const { - multiplies mul_op; - - Output y_r = Output(lhs.real()); - Output y_i = Output(lhs.imag()); - - return mul_op(y_r, y_r) + mul_op(y_i, y_i); - } -}; - -/// Fused multiply-add -template -struct multiply_add, complex, complex> { - CUTLASS_HOST_DEVICE - complex operator()( - complex const &a, - complex const &b, - complex const &c) const { - - T real = c.real(); - T imag = c.imag(); - - real += a.real() * b.real(); - real += -a.imag() * b.imag(); - imag += a.real() * b.imag(); - imag += a.imag () * b.real(); - - return complex{ - real, - imag - }; - } -}; - -/// Fused multiply-add -template -struct multiply_add, T, complex> { - CUTLASS_HOST_DEVICE - complex operator()( - complex const &a, - T const &b, - complex const &c) const { - - T real = c.real(); - T imag = c.imag(); - - real += a.real() * b; - imag += a.imag () * b; - - return complex{ - real, - imag - }; - } -}; - -/// Fused multiply-add -template -struct multiply_add, complex> { - CUTLASS_HOST_DEVICE - complex operator()( - T const &a, - complex const &b, - complex const &c) const { - - T real = c.real(); - T imag = c.imag(); - - real += a * b.real(); - imag += a * b.imag(); - - return complex{ - real, - imag - }; - } -}; - -/// Conjugate -template -struct conjugate> { - CUTLASS_HOST_DEVICE - complex operator()(complex const &a) const { - // Invoke the complex overload specifically, rather than - // wasting the compiler's effort on overload resolution. - return cutlass::conj(a); - } -}; - -#if ! defined(__CUDACC_RTC__) -template <> -struct conjugate { - CUTLASS_HOST_DEVICE - cuFloatComplex operator()(cuFloatComplex const& z) const { - return make_cuFloatComplex(z.x, -z.y); - } -}; - -template <> -struct conjugate { - CUTLASS_HOST_DEVICE - cuDoubleComplex operator()(cuDoubleComplex const& z) const { - return make_cuDoubleComplex(z.x, -z.y); - } -}; -#endif - -/// Computes the square of a difference with optional conversion -template -struct magnitude_squared_difference, Output> { - CUTLASS_HOST_DEVICE - Output operator()(complex lhs, complex rhs) const { - multiplies mul_op; - - Output y_r = Output(lhs.real()) - Output(rhs.real()); - Output y_i = Output(lhs.imag()) - Output(rhs.imag()); - - return mul_op(y_r, y_r) + mul_op(y_i, y_i); - } -}; - -/// Reduces value into the data pointed to by ptr (complex specialization) -template -struct atomic_add> { - CUTLASS_DEVICE - void operator()(complex *ptr, const complex &data) - { - data.red(ptr); - } -}; - - -////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/constants.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/constants.h deleted file mode 100644 index f5df01726b3f4dbc88bf2fd6f15092cff2b55fac..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/constants.h +++ /dev/null @@ -1,1239 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/* \file - \brief Boost-style constant definitions for floating-point types. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/complex.h" - -/////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace constants { - -/////////////////////////////////////////////////////////////////////////////////// - -// -// Primary templates -// - -/// Returns 1, the multiplicative identity element -template CUTLASS_HOST_DEVICE T one(); - -/// Returns 0, the additive identity element -template CUTLASS_HOST_DEVICE T zero(); - -/// Returns 2 -template CUTLASS_HOST_DEVICE T two(); - -/// Returns pi, approximately 3.141 -template CUTLASS_HOST_DEVICE T pi(); - -/// Returns 2 * pi -template CUTLASS_HOST_DEVICE T two_pi(); - -/// Returns pi / 2 -template CUTLASS_HOST_DEVICE T half_pi(); - -/// Returns sqrt(pi) -template CUTLASS_HOST_DEVICE T root_pi(); - -/// Returns sqrt(pi / 2) -template CUTLASS_HOST_DEVICE T root_half_pi(); - -/// Returns sqrt(2 * pi) -template CUTLASS_HOST_DEVICE T root_two_pi(); - -/// Returns sqrt(ln(4)) -template CUTLASS_HOST_DEVICE T root_ln_four(); - -/// Returns e, approximately 2.718... -template CUTLASS_HOST_DEVICE T e(); - -/// Returns (1/2) -template CUTLASS_HOST_DEVICE T half(); - -/// Returns sqrt(2), approximately 1.414... -template CUTLASS_HOST_DEVICE T root_two(); - -/// Returns sqrt(2)/2, approximately 0.707... -template CUTLASS_HOST_DEVICE T half_root_two(); - -/// Returns ln(2), approximately 0.693... -template CUTLASS_HOST_DEVICE T ln_two(); - -/// Returns ln(ln(2)), approximately -0.3665... -template CUTLASS_HOST_DEVICE T ln_ln_two(); - -/// Returns 1/3, approximately 0.333... -template CUTLASS_HOST_DEVICE T third(); - -/// Returns 2/3, approximately 0.666... -template CUTLASS_HOST_DEVICE T twothirds(); - -/// Returns pi - 3, approximately 0.1416... -template CUTLASS_HOST_DEVICE T pi_minus_three(); - -/// Returns 4 - pi, approximately 0.858... -template CUTLASS_HOST_DEVICE T four_minus_pi(); - - -///////////////////////////////////////////////////////////////////////////////////// - -// Specialization for double - -/// Returns 1, the multiplicative identity element (specialization for double) -template <> CUTLASS_HOST_DEVICE double one() { - uint64_t bits = 0x3ff0000000000000ull; - return reinterpret_cast(bits); -} - -/// Returns 1, the multiplicative identity element (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex one< complex >() { - return complex(one(), double()); -} - -/// Returns 0, the additive identity element (specialization for double) -template <> CUTLASS_HOST_DEVICE double zero() { - uint64_t bits = 0x0ull; - return reinterpret_cast(bits); -} - -/// Returns 0, the additive identity element (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex zero< complex >() { - return complex(zero(), double()); -} - -/// Returns 2 (specialization for double) -template <> CUTLASS_HOST_DEVICE double two() { - uint64_t bits = 0x4000000000000000ull; - return reinterpret_cast(bits); -} - -/// Returns 2 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex two< complex >() { - return complex(two(), double()); -} - -/// Returns pi, approximately 3.141 (specialization for double) -template <> CUTLASS_HOST_DEVICE double pi() { - uint64_t bits = 0x400921fb54442d18ull; - return reinterpret_cast(bits); -} - -/// Returns pi, approximately 3.141 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex pi< complex >() { - return complex(pi(), double()); -} - -/// Returns 2 * pi (specialization for double) -template <> CUTLASS_HOST_DEVICE double two_pi() { - uint64_t bits = 0x401921fb54442d18ull; - return reinterpret_cast(bits); -} - -/// Returns 2 * pi (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { - return complex(two_pi(), double()); -} - -/// Returns pi / 2 (specialization for double) -template <> CUTLASS_HOST_DEVICE double half_pi() { - uint64_t bits = 0x3ff921fb54442d18ull; - return reinterpret_cast(bits); -} - -/// Returns pi / 2 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { - return complex(half_pi(), double()); -} - -/// Returns sqrt(pi) (specialization for double) -template <> CUTLASS_HOST_DEVICE double root_pi() { - uint64_t bits = 0x3ffc5bf891b4ef6aull; - return reinterpret_cast(bits); -} - -/// Returns sqrt(pi) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { - return complex(root_pi(), double()); -} - -/// Returns sqrt(pi / 2) (specialization for double) -template <> CUTLASS_HOST_DEVICE double root_half_pi() { - uint64_t bits = 0x3ff40d931ff62705ull; - return reinterpret_cast(bits); -} - -/// Returns sqrt(pi / 2) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { - return complex(root_half_pi(), double()); -} - -/// Returns sqrt(2 * pi) (specialization for double) -template <> CUTLASS_HOST_DEVICE double root_two_pi() { - uint64_t bits = 0x40040d931ff62705ull; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2 * pi) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { - return complex(root_two_pi(), double()); -} - -/// Returns sqrt(ln(4)) (specialization for double) -template <> CUTLASS_HOST_DEVICE double root_ln_four() { - uint64_t bits = 0x3ff2d6abe44afc43ull; - return reinterpret_cast(bits); -} - -/// Returns sqrt(ln(4)) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { - return complex(root_ln_four(), double()); -} - -/// Returns e, approximately 2.718... (specialization for double) -template <> CUTLASS_HOST_DEVICE double e() { - uint64_t bits = 0x4005bf0a8b145769ull; - return reinterpret_cast(bits); -} - -/// Returns e, approximately 2.718... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex e< complex >() { - return complex(e(), double()); -} - -/// Returns (1/2) (specialization for double) -template <> CUTLASS_HOST_DEVICE double half() { - uint64_t bits = 0x3fe0000000000000ull; - return reinterpret_cast(bits); -} - -/// Returns (1/2) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half< complex >() { - return complex(half(), double()); -} - -/// Returns sqrt(2), approximately 1.414... (specialization for double) -template <> CUTLASS_HOST_DEVICE double root_two() { - uint64_t bits = 0x3ff6a09e667f3bcdull; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2), approximately 1.414... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { - return complex(root_two(), double()); -} - -/// Returns sqrt(2)/2, approximately 0.707... (specialization for double) -template <> CUTLASS_HOST_DEVICE double half_root_two() { - uint64_t bits = 0x3fe6a09e667f3bcdull; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { - return complex(half_root_two(), double()); -} - -/// Returns ln(2), approximately 0.693... (specialization for double) -template <> CUTLASS_HOST_DEVICE double ln_two() { - uint64_t bits = 0x3fe62e42fefa39efull; - return reinterpret_cast(bits); -} - -/// Returns ln(2), approximately 0.693... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { - return complex(ln_two(), double()); -} - -/// Returns ln(ln(2)), approximately -0.3665... (specialization for double) -template <> CUTLASS_HOST_DEVICE double ln_ln_two() { - uint64_t bits = 0xbfd774f29bdd6b9full; - return reinterpret_cast(bits); -} - -/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { - return complex(ln_ln_two(), double()); -} - -/// Returns 1/3, approximately 0.333... (specialization for double) -template <> CUTLASS_HOST_DEVICE double third() { - uint64_t bits = 0x3fd5555555555555ull; - return reinterpret_cast(bits); -} - -/// Returns 1/3, approximately 0.333... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex third< complex >() { - return complex(third(), double()); -} - -/// Returns 2/3, approximately 0.666... (specialization for double) -template <> CUTLASS_HOST_DEVICE double twothirds() { - uint64_t bits = 0x3fe5555555555555ull; - return reinterpret_cast(bits); -} - -/// Returns 2/3, approximately 0.666... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { - return complex(twothirds(), double()); -} - -/// Returns pi - 3, approximately 0.1416... (specialization for double) -template <> CUTLASS_HOST_DEVICE double pi_minus_three() { - uint64_t bits = 0x3fc21fb54442d180ull; - return reinterpret_cast(bits); -} - -/// Returns pi - 3, approximately 0.1416... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { - return complex(pi_minus_three(), double()); -} - -/// Returns 4 - pi, approximately 0.858... (specialization for double) -template <> CUTLASS_HOST_DEVICE double four_minus_pi() { - uint64_t bits = 0x3feb7812aeef4ba0ull; - return reinterpret_cast(bits); -} - -/// Returns 4 - pi, approximately 0.858... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { - return complex(four_minus_pi(), double()); -} - -///////////////////////////////////////////////////////////////////////////////////// - -// Specialization for float - -/// Returns 1, the multiplicative identity element (specialization for float) -template <> CUTLASS_HOST_DEVICE float one() { - uint32_t bits = 0x3f800000u; - return reinterpret_cast(bits); -} - -/// Returns 1, the multiplicative identity element (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex one< complex >() { - return complex(one(), float()); -} - -/// Returns 0, the additive identity element (specialization for float) -template <> CUTLASS_HOST_DEVICE float zero() { - uint32_t bits = 0x0u; - return reinterpret_cast(bits); -} - -/// Returns 0, the additive identity element (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex zero< complex >() { - return complex(zero(), float()); -} - -/// Returns 2 (specialization for float) -template <> CUTLASS_HOST_DEVICE float two() { - uint32_t bits = 0x40000000u; - return reinterpret_cast(bits); -} - -/// Returns 2 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex two< complex >() { - return complex(two(), float()); -} - -/// Returns pi, approximately 3.141 (specialization for float) -template <> CUTLASS_HOST_DEVICE float pi() { - uint32_t bits = 0x40490fdbu; - return reinterpret_cast(bits); -} - -/// Returns pi, approximately 3.141 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex pi< complex >() { - return complex(pi(), float()); -} - -/// Returns 2 * pi (specialization for float) -template <> CUTLASS_HOST_DEVICE float two_pi() { - uint32_t bits = 0x40c90fdbu; - return reinterpret_cast(bits); -} - -/// Returns 2 * pi (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { - return complex(two_pi(), float()); -} - -/// Returns pi / 2 (specialization for float) -template <> CUTLASS_HOST_DEVICE float half_pi() { - uint32_t bits = 0x3fc90fdbu; - return reinterpret_cast(bits); -} - -/// Returns pi / 2 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { - return complex(half_pi(), float()); -} - -/// Returns sqrt(pi) (specialization for float) -template <> CUTLASS_HOST_DEVICE float root_pi() { - uint32_t bits = 0x3fe2dfc5u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(pi) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { - return complex(root_pi(), float()); -} - -/// Returns sqrt(pi / 2) (specialization for float) -template <> CUTLASS_HOST_DEVICE float root_half_pi() { - uint32_t bits = 0x3fa06c99u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(pi / 2) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { - return complex(root_half_pi(), float()); -} - -/// Returns sqrt(2 * pi) (specialization for float) -template <> CUTLASS_HOST_DEVICE float root_two_pi() { - uint32_t bits = 0x40206c99u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2 * pi) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { - return complex(root_two_pi(), float()); -} - -/// Returns sqrt(ln(4)) (specialization for float) -template <> CUTLASS_HOST_DEVICE float root_ln_four() { - uint32_t bits = 0x3f96b55fu; - return reinterpret_cast(bits); -} - -/// Returns sqrt(ln(4)) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { - return complex(root_ln_four(), float()); -} - -/// Returns e, approximately 2.718... (specialization for float) -template <> CUTLASS_HOST_DEVICE float e() { - uint32_t bits = 0x402df854u; - return reinterpret_cast(bits); -} - -/// Returns e, approximately 2.718... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex e< complex >() { - return complex(e(), float()); -} - -/// Returns (1/2) (specialization for float) -template <> CUTLASS_HOST_DEVICE float half() { - uint32_t bits = 0x3f000000u; - return reinterpret_cast(bits); -} - -/// Returns (1/2) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half< complex >() { - return complex(half(), float()); -} - -/// Returns sqrt(2), approximately 1.414... (specialization for float) -template <> CUTLASS_HOST_DEVICE float root_two() { - uint32_t bits = 0x3fb504f3u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2), approximately 1.414... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { - return complex(root_two(), float()); -} - -/// Returns sqrt(2)/2, approximately 0.707... (specialization for float) -template <> CUTLASS_HOST_DEVICE float half_root_two() { - uint32_t bits = 0x3f3504f3u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { - return complex(half_root_two(), float()); -} - -/// Returns ln(2), approximately 0.693... (specialization for float) -template <> CUTLASS_HOST_DEVICE float ln_two() { - uint32_t bits = 0x3f317218u; - return reinterpret_cast(bits); -} - -/// Returns ln(2), approximately 0.693... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { - return complex(ln_two(), float()); -} - -/// Returns ln(ln(2)), approximately -0.3665... (specialization for float) -template <> CUTLASS_HOST_DEVICE float ln_ln_two() { - uint32_t bits = 0xbebba795u; - return reinterpret_cast(bits); -} - -/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { - return complex(ln_ln_two(), float()); -} - -/// Returns 1/3, approximately 0.333... (specialization for float) -template <> CUTLASS_HOST_DEVICE float third() { - uint32_t bits = 0x3eaaaaabu; - return reinterpret_cast(bits); -} - -/// Returns 1/3, approximately 0.333... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex third< complex >() { - return complex(third(), float()); -} - -/// Returns 2/3, approximately 0.666... (specialization for float) -template <> CUTLASS_HOST_DEVICE float twothirds() { - uint32_t bits = 0x3f2aaaabu; - return reinterpret_cast(bits); -} - -/// Returns 2/3, approximately 0.666... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { - return complex(twothirds(), float()); -} - -/// Returns pi - 3, approximately 0.1416... (specialization for float) -template <> CUTLASS_HOST_DEVICE float pi_minus_three() { - uint32_t bits = 0x3e10fdaau; - return reinterpret_cast(bits); -} - -/// Returns pi - 3, approximately 0.1416... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { - return complex(pi_minus_three(), float()); -} - -/// Returns 4 - pi, approximately 0.858... (specialization for float) -template <> CUTLASS_HOST_DEVICE float four_minus_pi() { - uint32_t bits = 0x3f5bc095u; - return reinterpret_cast(bits); -} - -/// Returns 4 - pi, approximately 0.858... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { - return complex(four_minus_pi(), float()); -} - -///////////////////////////////////////////////////////////////////////////////////// - -// Specialization for tfloat32_t - -/// Returns 1, the multiplicative identity element (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t one() { - uint32_t bits = 0x3f801000u; - return reinterpret_cast(bits); -} - -/// Returns 1, the multiplicative identity element (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex one< complex >() { - return complex(one(), tfloat32_t()); -} - -/// Returns 0, the additive identity element (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t zero() { - uint32_t bits = 0x1000u; - return reinterpret_cast(bits); -} - -/// Returns 0, the additive identity element (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex zero< complex >() { - return complex(zero(), tfloat32_t()); -} - -/// Returns 2 (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t two() { - uint32_t bits = 0x40001000u; - return reinterpret_cast(bits); -} - -/// Returns 2 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex two< complex >() { - return complex(two(), tfloat32_t()); -} - -/// Returns pi, approximately 3.141 (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t pi() { - uint32_t bits = 0x40491fdbu; - return reinterpret_cast(bits); -} - -/// Returns pi, approximately 3.141 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex pi< complex >() { - return complex(pi(), tfloat32_t()); -} - -/// Returns 2 * pi (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t two_pi() { - uint32_t bits = 0x40c91fdbu; - return reinterpret_cast(bits); -} - -/// Returns 2 * pi (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { - return complex(two_pi(), tfloat32_t()); -} - -/// Returns pi / 2 (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t half_pi() { - uint32_t bits = 0x3fc91fdbu; - return reinterpret_cast(bits); -} - -/// Returns pi / 2 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { - return complex(half_pi(), tfloat32_t()); -} - -/// Returns sqrt(pi) (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t root_pi() { - uint32_t bits = 0x3fe2efc5u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(pi) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { - return complex(root_pi(), tfloat32_t()); -} - -/// Returns sqrt(pi / 2) (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t root_half_pi() { - uint32_t bits = 0x3fa07c99u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(pi / 2) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { - return complex(root_half_pi(), tfloat32_t()); -} - -/// Returns sqrt(2 * pi) (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t root_two_pi() { - uint32_t bits = 0x40207c99u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2 * pi) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { - return complex(root_two_pi(), tfloat32_t()); -} - -/// Returns sqrt(ln(4)) (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t root_ln_four() { - uint32_t bits = 0x3f96c55fu; - return reinterpret_cast(bits); -} - -/// Returns sqrt(ln(4)) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { - return complex(root_ln_four(), tfloat32_t()); -} - -/// Returns e, approximately 2.718... (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t e() { - uint32_t bits = 0x402e0854u; - return reinterpret_cast(bits); -} - -/// Returns e, approximately 2.718... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex e< complex >() { - return complex(e(), tfloat32_t()); -} - -/// Returns (1/2) (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t half() { - uint32_t bits = 0x3f001000u; - return reinterpret_cast(bits); -} - -/// Returns (1/2) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half< complex >() { - return complex(half(), tfloat32_t()); -} - -/// Returns sqrt(2), approximately 1.414... (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t root_two() { - uint32_t bits = 0x3fb514f3u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2), approximately 1.414... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { - return complex(root_two(), tfloat32_t()); -} - -/// Returns sqrt(2)/2, approximately 0.707... (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t half_root_two() { - uint32_t bits = 0x3f3514f3u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { - return complex(half_root_two(), tfloat32_t()); -} - -/// Returns ln(2), approximately 0.693... (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t ln_two() { - uint32_t bits = 0x3f318218u; - return reinterpret_cast(bits); -} - -/// Returns ln(2), approximately 0.693... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { - return complex(ln_two(), tfloat32_t()); -} - -/// Returns ln(ln(2)), approximately -0.3665... (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t ln_ln_two() { - uint32_t bits = 0xbebbb795u; - return reinterpret_cast(bits); -} - -/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { - return complex(ln_ln_two(), tfloat32_t()); -} - -/// Returns 1/3, approximately 0.333... (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t third() { - uint32_t bits = 0x3eaabaabu; - return reinterpret_cast(bits); -} - -/// Returns 1/3, approximately 0.333... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex third< complex >() { - return complex(third(), tfloat32_t()); -} - -/// Returns 2/3, approximately 0.666... (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t twothirds() { - uint32_t bits = 0x3f2abaabu; - return reinterpret_cast(bits); -} - -/// Returns 2/3, approximately 0.666... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { - return complex(twothirds(), tfloat32_t()); -} - -/// Returns pi - 3, approximately 0.1416... (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t pi_minus_three() { - uint32_t bits = 0x3e110daau; - return reinterpret_cast(bits); -} - -/// Returns pi - 3, approximately 0.1416... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { - return complex(pi_minus_three(), tfloat32_t()); -} - -/// Returns 4 - pi, approximately 0.858... (specialization for tfloat32_t) -template <> CUTLASS_HOST_DEVICE tfloat32_t four_minus_pi() { - uint32_t bits = 0x3f5bd095u; - return reinterpret_cast(bits); -} - -/// Returns 4 - pi, approximately 0.858... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { - return complex(four_minus_pi(), tfloat32_t()); -} - -///////////////////////////////////////////////////////////////////////////////////// - -// Specialization for half_t - -/// Returns 1, the multiplicative identity element (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t one() { - uint16_t bits = 0x3c00u; - return reinterpret_cast(bits); -} - -/// Returns 1, the multiplicative identity element (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex one< complex >() { - return complex(one(), half_t()); -} - -/// Returns 0, the additive identity element (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t zero() { - uint16_t bits = 0x0u; - return reinterpret_cast(bits); -} - -/// Returns 0, the additive identity element (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex zero< complex >() { - return complex(zero(), half_t()); -} - -/// Returns 2 (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t two() { - uint16_t bits = 0x4000u; - return reinterpret_cast(bits); -} - -/// Returns 2 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex two< complex >() { - return complex(two(), half_t()); -} - -/// Returns pi, approximately 3.141 (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t pi() { - uint16_t bits = 0x4248u; - return reinterpret_cast(bits); -} - -/// Returns pi, approximately 3.141 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex pi< complex >() { - return complex(pi(), half_t()); -} - -/// Returns 2 * pi (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t two_pi() { - uint16_t bits = 0x4648u; - return reinterpret_cast(bits); -} - -/// Returns 2 * pi (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { - return complex(two_pi(), half_t()); -} - -/// Returns pi / 2 (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t half_pi() { - uint16_t bits = 0x3e48u; - return reinterpret_cast(bits); -} - -/// Returns pi / 2 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { - return complex(half_pi(), half_t()); -} - -/// Returns sqrt(pi) (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t root_pi() { - uint16_t bits = 0x3f17u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(pi) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { - return complex(root_pi(), half_t()); -} - -/// Returns sqrt(pi / 2) (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t root_half_pi() { - uint16_t bits = 0x3d03u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(pi / 2) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { - return complex(root_half_pi(), half_t()); -} - -/// Returns sqrt(2 * pi) (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t root_two_pi() { - uint16_t bits = 0x4103u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2 * pi) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { - return complex(root_two_pi(), half_t()); -} - -/// Returns sqrt(ln(4)) (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t root_ln_four() { - uint16_t bits = 0x3cb6u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(ln(4)) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { - return complex(root_ln_four(), half_t()); -} - -/// Returns e, approximately 2.718... (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t e() { - uint16_t bits = 0x4170u; - return reinterpret_cast(bits); -} - -/// Returns e, approximately 2.718... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex e< complex >() { - return complex(e(), half_t()); -} - -/// Returns (1/2) (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t half() { - uint16_t bits = 0x3800u; - return reinterpret_cast(bits); -} - -/// Returns (1/2) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half< complex >() { - return complex(half(), half_t()); -} - -/// Returns sqrt(2), approximately 1.414... (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t root_two() { - uint16_t bits = 0x3da8u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2), approximately 1.414... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { - return complex(root_two(), half_t()); -} - -/// Returns sqrt(2)/2, approximately 0.707... (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t half_root_two() { - uint16_t bits = 0x39a8u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { - return complex(half_root_two(), half_t()); -} - -/// Returns ln(2), approximately 0.693... (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t ln_two() { - uint16_t bits = 0x398cu; - return reinterpret_cast(bits); -} - -/// Returns ln(2), approximately 0.693... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { - return complex(ln_two(), half_t()); -} - -/// Returns ln(ln(2)), approximately -0.3665... (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t ln_ln_two() { - uint16_t bits = 0xb5ddu; - return reinterpret_cast(bits); -} - -/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { - return complex(ln_ln_two(), half_t()); -} - -/// Returns 1/3, approximately 0.333... (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t third() { - uint16_t bits = 0x3555u; - return reinterpret_cast(bits); -} - -/// Returns 1/3, approximately 0.333... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex third< complex >() { - return complex(third(), half_t()); -} - -/// Returns 2/3, approximately 0.666... (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t twothirds() { - uint16_t bits = 0x3955u; - return reinterpret_cast(bits); -} - -/// Returns 2/3, approximately 0.666... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { - return complex(twothirds(), half_t()); -} - -/// Returns pi - 3, approximately 0.1416... (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t pi_minus_three() { - uint16_t bits = 0x3088u; - return reinterpret_cast(bits); -} - -/// Returns pi - 3, approximately 0.1416... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { - return complex(pi_minus_three(), half_t()); -} - -/// Returns 4 - pi, approximately 0.858... (specialization for half_t) -template <> CUTLASS_HOST_DEVICE half_t four_minus_pi() { - uint16_t bits = 0x3adeu; - return reinterpret_cast(bits); -} - -/// Returns 4 - pi, approximately 0.858... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { - return complex(four_minus_pi(), half_t()); -} - -///////////////////////////////////////////////////////////////////////////////////// - -// Specialization for bfloat16_t - -/// Returns 1, the multiplicative identity element (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t one() { - uint16_t bits = 0x3f80u; - return reinterpret_cast(bits); -} - -/// Returns 1, the multiplicative identity element (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex one< complex >() { - return complex(one(), bfloat16_t()); -} - -/// Returns 0, the additive identity element (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t zero() { - uint16_t bits = 0x0u; - return reinterpret_cast(bits); -} - -/// Returns 0, the additive identity element (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex zero< complex >() { - return complex(zero(), bfloat16_t()); -} - -/// Returns 2 (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t two() { - uint16_t bits = 0x4000u; - return reinterpret_cast(bits); -} - -/// Returns 2 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex two< complex >() { - return complex(two(), bfloat16_t()); -} - -/// Returns pi, approximately 3.141 (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t pi() { - uint16_t bits = 0x4049u; - return reinterpret_cast(bits); -} - -/// Returns pi, approximately 3.141 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex pi< complex >() { - return complex(pi(), bfloat16_t()); -} - -/// Returns 2 * pi (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t two_pi() { - uint16_t bits = 0x40c9u; - return reinterpret_cast(bits); -} - -/// Returns 2 * pi (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { - return complex(two_pi(), bfloat16_t()); -} - -/// Returns pi / 2 (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t half_pi() { - uint16_t bits = 0x3fc9u; - return reinterpret_cast(bits); -} - -/// Returns pi / 2 (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { - return complex(half_pi(), bfloat16_t()); -} - -/// Returns sqrt(pi) (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t root_pi() { - uint16_t bits = 0x3fe3u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(pi) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { - return complex(root_pi(), bfloat16_t()); -} - -/// Returns sqrt(pi / 2) (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t root_half_pi() { - uint16_t bits = 0x3fa0u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(pi / 2) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { - return complex(root_half_pi(), bfloat16_t()); -} - -/// Returns sqrt(2 * pi) (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t root_two_pi() { - uint16_t bits = 0x4020u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2 * pi) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { - return complex(root_two_pi(), bfloat16_t()); -} - -/// Returns sqrt(ln(4)) (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t root_ln_four() { - uint16_t bits = 0x3f97u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(ln(4)) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { - return complex(root_ln_four(), bfloat16_t()); -} - -/// Returns e, approximately 2.718... (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t e() { - uint16_t bits = 0x402eu; - return reinterpret_cast(bits); -} - -/// Returns e, approximately 2.718... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex e< complex >() { - return complex(e(), bfloat16_t()); -} - -/// Returns (1/2) (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t half() { - uint16_t bits = 0x3f00u; - return reinterpret_cast(bits); -} - -/// Returns (1/2) (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half< complex >() { - return complex(half(), bfloat16_t()); -} - -/// Returns sqrt(2), approximately 1.414... (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t root_two() { - uint16_t bits = 0x3fb5u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2), approximately 1.414... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { - return complex(root_two(), bfloat16_t()); -} - -/// Returns sqrt(2)/2, approximately 0.707... (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t half_root_two() { - uint16_t bits = 0x3f35u; - return reinterpret_cast(bits); -} - -/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { - return complex(half_root_two(), bfloat16_t()); -} - -/// Returns ln(2), approximately 0.693... (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t ln_two() { - uint16_t bits = 0x3f31u; - return reinterpret_cast(bits); -} - -/// Returns ln(2), approximately 0.693... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { - return complex(ln_two(), bfloat16_t()); -} - -/// Returns ln(ln(2)), approximately -0.3665... (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t ln_ln_two() { - uint16_t bits = 0xbebcu; - return reinterpret_cast(bits); -} - -/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { - return complex(ln_ln_two(), bfloat16_t()); -} - -/// Returns 1/3, approximately 0.333... (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t third() { - uint16_t bits = 0x3eabu; - return reinterpret_cast(bits); -} - -/// Returns 1/3, approximately 0.333... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex third< complex >() { - return complex(third(), bfloat16_t()); -} - -/// Returns 2/3, approximately 0.666... (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t twothirds() { - uint16_t bits = 0x3f2bu; - return reinterpret_cast(bits); -} - -/// Returns 2/3, approximately 0.666... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { - return complex(twothirds(), bfloat16_t()); -} - -/// Returns pi - 3, approximately 0.1416... (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t pi_minus_three() { - uint16_t bits = 0x3e11u; - return reinterpret_cast(bits); -} - -/// Returns pi - 3, approximately 0.1416... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { - return complex(pi_minus_three(), bfloat16_t()); -} - -/// Returns 4 - pi, approximately 0.858... (specialization for bfloat16_t) -template <> CUTLASS_HOST_DEVICE bfloat16_t four_minus_pi() { - uint16_t bits = 0x3f5cu; - return reinterpret_cast(bits); -} - -/// Returns 4 - pi, approximately 0.858... (specialization for complex) -template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { - return complex(four_minus_pi(), bfloat16_t()); -} -/////////////////////////////////////////////////////////////////////////////////// - -} // namespace constants -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_builder.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_builder.hpp deleted file mode 100644 index e032f9599a5e76eb1e8dd6b5279ae9a42ce9c9b4..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_builder.hpp +++ /dev/null @@ -1,94 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/detail/dependent_false.hpp" -#include "cutlass/conv/collective/collective_conv.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Used to specify stage counts or dispatch to automatic computation of stage count -template -struct StageCount { - static constexpr int value = num_stages; - - StageCount() = default; - explicit StageCount(cute::Int) {} -}; - -template -struct StageCountAutoCarveout { - static constexpr int bytes = carveout_bytes; - - StageCountAutoCarveout() = default; - explicit StageCountAutoCarveout(cute::Int) {} -}; - -// Used to automatically let the builder pick the kernel schedule. -// Can be overridden with kernel schedule tags in cutlass/conv/dispatch_policy.hpp -struct KernelScheduleAuto {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - class ArchTag, - class OpClass, - conv::Operator, - class ElementA, - class GmemLayoutA, - int AlignmentA, - class ElementB, - class GmemLayoutB, - int AlignmentB, - class ElementAccumulator, - class TileShape_MNK, - class ClusterShape_MNK, - class StageCountType, - class KernelScheduleType, - class Enable = void -> -struct CollectiveBuilder { - static_assert(cutlass::detail::dependent_false, "Could not build a collective for given parameters."); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "builders/sm90_gmma_builder.inl" -#include "builders/sm100_umma_builder.inl" -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_conv.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_conv.hpp deleted file mode 100644 index f0bb596fe02b36d350a1d1065ff5001794eba170..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_conv.hpp +++ /dev/null @@ -1,63 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/detail/dependent_false.hpp" -#include "cutlass/conv/collective/detail.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - class DispatchPolicy, - class TileShape, - class ElementA, - class ElementB, - class TiledMma, - class TileTraitsA, - class TileTraitsB -> -struct CollectiveConv { - static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "sm90_implicit_gemm_gmma_ss_warpspecialized.hpp" -#include "sm100_implicit_gemm_umma_warpspecialized.hpp" -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/detail.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/detail.hpp deleted file mode 100644 index af541a940f787528d213f068915ce0aa5997a82f..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/detail.hpp +++ /dev/null @@ -1,271 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/conv/convnd_problem_shape.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv::collective::detail { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Construct the stride types for conv collectives based on the dispatch policy, strides 64b by default -template -constexpr auto -sm90_dispatch_policy_to_stride_A() { - if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) { - // Maps to modes ((w,n), C) - if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { - return cute::Stride, - cute::Int<1>>{}; - } - // Maps to modes ((w,h,n), C) - else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { - return cute::Stride, - cute::Int<1>>{}; - } - // Maps to modes ((w,h,d,n), C) - else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { - return cute::Stride, - cute::Int<1>>{}; - } - // error dims assert - else { - static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); - } - } - else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) { - // Maps to modes (k, nq/npq/nzpq) - if constexpr (DispatchPolicy::NumSpatialDimensions == 1 || - DispatchPolicy::NumSpatialDimensions == 2 || - DispatchPolicy::NumSpatialDimensions == 3) { - return cute::Stride, int64_t>{}; - } - // error dims assert - else { - static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); - } - } - else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) { - // Maps to modes ((q,n), K) - if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { - return cute::Stride, - cute::Int<1>>{}; - } - // Maps to modes ((q,p,n), K) - else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { - return cute::Stride, - cute::Int<1>>{}; - } - // Maps to modes ((q,p,z,n), K) - else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { - return cute::Stride, - cute::Int<1>>{}; - } - // error dims assert - else { - static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Unsupported ConvOp."); - } -} - -// Construct the stirde types for conv collectives based on the dispatch policy, strides 64b by default -template -constexpr auto -sm90_dispatch_policy_to_stride_B() { - if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) { - // Maps to modes (k, (C,s)) - if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { - return cute::Stride, int64_t>>{}; - } - // Maps to modes (k, (C,s,r)) - else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { - return cute::Stride, int64_t, int64_t>>{}; - } - // Maps to modes (k, (C,s,r,t)) - else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { - return cute::Stride, int64_t, int64_t, int64_t>>{}; - } - // error dims assert - else { - static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); - } - } - else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) { - // Maps to modes (C, (w,n)) - if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { - return cute::Stride, - cute::Stride>{}; - } - // Maps to modes (C, (w,h,n)) - else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { - return cute::Stride, - cute::Stride>{}; - } - // Maps to modes (C, (w,h,d,n)) - else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { - return cute::Stride, - cute::Stride>{}; - } - // error dims assert - else { - static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); - } - } - else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) { - // Maps to modes (C, (k,s)) - if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { - return cute::Stride, cute::Stride>{}; - } - // Maps to modes (C, (k,s,r)) - else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { - return cute::Stride, cute::Stride>{}; - } - // Maps to modes (C, (k,s,r,t)) - else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { - return cute::Stride, cute::Stride>{}; - } - // error dims assert - else { - static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Unsupported ConvOp."); - } -} - - -template -constexpr auto -sm100_dispatch_policy_to_stride_A() { - return sm90_dispatch_policy_to_stride_A(); -} - -template -constexpr auto -sm100_dispatch_policy_to_stride_B() { - return sm90_dispatch_policy_to_stride_B(); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Compute the lower/near corner, returning it as a cute::array in [W,H,D] order -template -CUTLASS_HOST_DEVICE -constexpr auto -compute_lower_corner_whd(ConvProblemShape const& problem_shape) { - using cute::for_each; - using cute::make_seq; - - cute::array lower{}; - if constexpr (ConvOp == conv::Operator::kFprop || - ConvOp == conv::Operator::kWgrad) { - for_each(make_seq{}, [&](auto i) { - lower[NumSpatialDimensions-1-i] = -1 * problem_shape.lower_padding[i]; - }); - } - else if constexpr (ConvOp == conv::Operator::kDgrad) { - for_each(make_seq{}, [&](auto i) { - lower[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] - - (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; - }); - } - return lower; -} - -// Computes the upper/far corner, returning it as a cute::array in [W,H,D] order -template -CUTLASS_HOST_DEVICE -constexpr auto -compute_upper_corner_whd(ConvProblemShape const& problem_shape) { - using cute::for_each; - using cute::make_seq; - - cute::array upper{}; - if constexpr (ConvOp == conv::Operator::kFprop) { - for_each(make_seq{}, [&](auto i) { - upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] - - (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; - }); - } - else if constexpr (ConvOp == conv::Operator::kWgrad) { - for_each(make_seq{}, [&](auto i) { - upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] - - (problem_shape.shape_C[i+1] - 1) * problem_shape.dilation[i]; - }); - } - else if constexpr (ConvOp == conv::Operator::kDgrad) { - for_each(make_seq{}, [&](auto i) { - upper[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] - - (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i] + problem_shape.shape_C[i+1] - problem_shape.shape_A[i+1]; - }); - } - return upper; -} - -// Compute the lower/near corner of (t,r,s), returning it as a cute::array in [S,R,T] order -template -CUTLASS_HOST_DEVICE -constexpr auto -compute_lower_srt(ConvProblemShape const& problem_shape) { - using cute::for_each; - using cute::make_seq; - - cute::array lower{}; - if constexpr (ConvOp == conv::Operator::kFprop || - ConvOp == conv::Operator::kWgrad) { - for_each(make_seq{}, [&](auto i) { - lower[NumSpatialDimensions-1-i] = 0; - }); - } - else if constexpr (ConvOp == conv::Operator::kDgrad) { - for_each(make_seq{}, [&](auto i) { - lower[NumSpatialDimensions-1-i] = (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; - }); - } - return lower; -} - -template struct is_im2col_load { static constexpr bool value = false; }; -template <> struct is_im2col_load { static constexpr bool value = true; }; -template <> struct is_im2col_load { static constexpr bool value = true; }; -template <> struct is_im2col_load { static constexpr bool value = true; }; -template <> struct is_im2col_load { static constexpr bool value = true; }; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv::collective::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp deleted file mode 100644 index d3c541c325004eb8488ca7353eed9a43fa4ae280..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp +++ /dev/null @@ -1,917 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/detail/cluster.hpp" - -#include "cutlass/conv/detail.hpp" -#include "cute/algorithm/functional.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cute/algorithm/gemm.hpp" -#include "cute/numeric/arithmetic_tuple.hpp" -#include "cutlass/trace.h" - -#if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0) -# include -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv::collective { -using namespace cute; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// WarpSpecialized Mainloop -// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one -template < - conv::Operator ConvOp, - int Stages, - int NumSpatialDims, - int SchedulerPipelineStageCount, - int AccumulatorPipelineStageCount, - class ClusterShape, // Static cluster shape or dynamic (int, int, _1) - class TileShapeMNKL_, // (MmaAtomShapeM, MmaAtomShapeN, TileK, optional: TileL) - class ElementA_, - class ElementB_, - class TiledMma_, - class TileTraitsA_, - class TileTraitsB_> -struct CollectiveConv< - MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< - ConvOp, - Stages, - NumSpatialDims, - SchedulerPipelineStageCount, - AccumulatorPipelineStageCount, - ClusterShape>, - TileShapeMNKL_, - ElementA_, - ElementB_, - TiledMma_, - TileTraitsA_, - TileTraitsB_> -{ - // - // Type Aliases - // - using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< - ConvOp, - Stages, - NumSpatialDims, - SchedulerPipelineStageCount, - AccumulatorPipelineStageCount, - ClusterShape>; - using TileShape = decltype(cute::take<0,3>(TileShapeMNKL_{})); // (MmaAtomShapeM, MmaAtomShapeN, TileK) - using ElementA = ElementA_; - using ElementB = ElementB_; - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = typename TileTraitsA_::GmemTiledCopy; - using GmemTiledCopyB = typename TileTraitsB_::GmemTiledCopy; - using SmemLayoutAtomA = typename TileTraitsA_::SmemLayoutAtom; - using SmemLayoutAtomB = typename TileTraitsB_::SmemLayoutAtom; - using ArchTag = typename DispatchPolicy::ArchTag; - static constexpr int NumSpatialDimensions = DispatchPolicy::NumSpatialDimensions; - static constexpr int NumTensorDimensions = NumSpatialDimensions + 2; - // deducde the kernel facing stride tuple types based on the dispatch policy (spatial dim, algo, etc.) - using StrideA = decltype(detail::sm100_dispatch_policy_to_stride_A()); - using StrideB = decltype(detail::sm100_dispatch_policy_to_stride_B()); - - static constexpr bool IsDynamicCluster = not cute::is_static_v; - static constexpr bool ConvertF32toTF32A = cute::is_same_v; - static constexpr bool ConvertF32toTF32B = cute::is_same_v; - using TmaInternalElementA = cute::conditional_t>>; - using TmaInternalElementB = cute::conditional_t>>; - - using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; - using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; - - // Determine MMA type: MMA_1SM vs MMA_2SM - using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; - - using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< - DispatchPolicy::Stages, - ClusterShape, - AtomThrShapeMNK>; - using MainloopPipelineState = typename MainloopPipeline::PipelineState; - - using ProblemShape = ConvProblemShape; - - CUTE_STATIC_ASSERT_V(evenly_divides(shape<0>(TileShape{}), tile_size<0>(TiledMma{})), "TileShape_M should be evenly divided by TiledMma_M"); - CUTE_STATIC_ASSERT_V(evenly_divides(shape<1>(TileShape{}), tile_size<1>(TiledMma{})) || (ConvOp == conv::Operator::kWgrad), "TileShape_N should be evenly divided by TiledMma_N"); - - using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); - - // Define A and B block shapes for reduced size TMA_LOADs - using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); - using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); - - static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, - "SmemLayoutAtom must evenly divide tile shape."); - static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, - "SmemLayoutAtom must evenly divide tile shape."); - - static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, - "SmemLayoutAtom must evenly divide tile shape."); - static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, - "SmemLayoutAtom must evenly divide tile shape."); - - // Tile along K mode first before tiling over MN. PIPE mode last as usual. - // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. - using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( - SmemLayoutAtomA{}, - append(MmaShapeA_MK{}, Int{}), - Step<_2,_1,_3>{})); - using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( - SmemLayoutAtomB{}, - append(MmaShapeB_NK{}, Int{}), - Step<_2,_1,_3>{})); - - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(cute::is_base_of::value && - cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - - static constexpr bool is_im2col_A = detail::is_im2col_load::value; - static constexpr bool is_im2col_B = detail::is_im2col_load::value; - static constexpr bool is_strided_dgrad = ConvOp == conv::Operator::kDgrad && not is_im2col_A && not is_im2col_B; - - static constexpr int TileShapeMNKLRank = rank(TileShapeMNKL_{}); - // If rank > 3, TileL exists and it is GroupsPerTile. The kernel is grouped conv now. - static constexpr bool is_grouped_wgrad = ConvOp == conv::Operator::kWgrad && TileShapeMNKLRank > 3; - - struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128, _0> { - cute::array_aligned> smem_A; - cute::array_aligned> smem_B; - } tensors; - - using PipelineStorage = typename MainloopPipeline::SharedStorage; - PipelineStorage pipeline; - }; - - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly - static constexpr uint32_t TmaTransactionBytes = - size(AtomThrShapeMNK{}) * (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast(sizeof(ElementA))) + - size(AtomThrShapeMNK{}) * (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast(sizeof(ElementB))); - - // Host side kernel arguments - struct Arguments { - ElementA const* ptr_A{nullptr}; - ElementB const* ptr_B{nullptr}; - }; - -private: - - // Note that for fprop and non-strided dgrad kernel, the tma load mode is im2col for tensor A and tiled for - // tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor - // B since operand A, B is swapped. - // For strided dgrad A and B are both tma tiled and not im2col - - template - static constexpr auto - get_tma_load_a_instance( - TensorA const& tensor_a, - ProblemShape const& problem_shape, - ClusterShapeVMNK const& cluster_shape_vmnk) { - - if constexpr (is_im2col_A) { - // compute the upper and lower corners based on the conv padding - auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); - auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); - auto lower_srt = detail::compute_lower_srt(problem_shape); - - // gbasis strides for dgrad kernel need to be negated - cute::array stride_srt{}; - for (int i = 0; i < NumSpatialDimensions; ++i) { - stride_srt[i] = ConvOp == conv::Operator::kDgrad ? - -problem_shape.dilation[NumSpatialDimensions-1-i] : - problem_shape.dilation[NumSpatialDimensions-1-i]; - } - - return make_im2col_tma_atom_A_sm100( - GmemTiledCopyA{}, - tensor_a, - SmemLayoutA{}(_,_,_,cute::Int<0>{}), - TileShape{}, - TiledMma{}, - cluster_shape_vmnk, - shape(lower_corner_whd), - shape(upper_corner_whd), - cute::reverse(shape(problem_shape.lower_padding)), - cute::reverse(shape(problem_shape.upper_padding)), - cute::reverse(shape(problem_shape.traversal_stride)), - shape(lower_srt), - shape(stride_srt)); - } - // TMA tiled mode for tensor A in wgrad and strided dgrad - else { - return make_tma_atom_A_sm100( - GmemTiledCopyA{}, - tensor_a, - SmemLayoutA{}(_,_,_,cute::Int<0>{}), - TileShape{}, - TiledMma{}, - cluster_shape_vmnk); - } - } - - template - static constexpr auto - get_tma_load_b_instance( - TensorB const& tensor_b, - ProblemShape const& problem_shape, - ClusterShapeVMNK const& cluster_shape_vmnk) { - - if constexpr (is_im2col_B) { - // compute the upper and lower corners based on the conv padding - auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); - auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); - auto lower_srt = detail::compute_lower_srt(problem_shape); - - return make_im2col_tma_atom_B_sm100( - GmemTiledCopyB{}, - tensor_b, - SmemLayoutB{}(_,_,_,cute::Int<0>{}), - TileShape{}, - TiledMma{}, - cluster_shape_vmnk, - shape(lower_corner_whd), - shape(upper_corner_whd), - cute::reverse(shape(problem_shape.lower_padding)), - cute::reverse(shape(problem_shape.upper_padding)), - cute::reverse(shape(problem_shape.traversal_stride)), - shape(lower_srt), - cute::reverse(shape(problem_shape.dilation))); - } - else { - return make_tma_atom_B_sm100( - GmemTiledCopyB{}, - tensor_b, - SmemLayoutB{}(_,_,_,cute::Int<0>{}), - TileShape{}, - TiledMma{}, - cluster_shape_vmnk); - } - } - -public: - - // Performs im2col transformations on the input of type ConvProblemShape - static constexpr auto - get_problem_shape_MNKL(ProblemShape const& problem_shape) { - if constexpr (is_im2col_A || is_im2col_B) { - // transformation + im2col linearization - return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape); - } - else { - // transformation - return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); - } - } - - // Device-side kernel params - // - // Arguments has the untransformed problem shape from the user. - // Params will have the transformed problem shape. - struct Params { - using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{})); - - using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), - make_tile(typename TiledMma::AtomThrID{}))); - - // Assumption: StrideA is congruent with Problem_MK - // Select TMA load type according to convolution operator. - using TensorShapeA = cute::conditional_t; - - using TensorShapeB = cute::conditional_t; - - using TMA_A = decltype(get_tma_load_a_instance( - make_tensor( - make_gmem_ptr(recast_ptr(nullptr)), - make_layout(TensorShapeA{}, StrideA{})), - ConvProblemShape{}, - ClusterLayout_VMNK{})); - - using TMA_B = decltype(get_tma_load_b_instance( - make_tensor( - make_gmem_ptr(recast_ptr(nullptr)), - make_layout(TensorShapeB{}, StrideB{})), - ConvProblemShape{}, - ClusterLayout_VMNK{})); - - // Members - TMA_A tma_load_a; - TMA_B tma_load_b; - TMA_A tma_load_a_fallback; - TMA_B tma_load_b_fallback; - dim3 cluster_shape_fallback; - }; - - // - // Constructor - // - CUTLASS_DEVICE - CollectiveConv(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) - : cluster_shape_(cluster_shape) - , block_rank_in_cluster_(block_rank_in_cluster) { - if constexpr (IsDynamicCluster) { - const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && - cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); - observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; - observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; - } - else { - observed_tma_load_a_ = ¶ms.tma_load_a; - observed_tma_load_b_ = ¶ms.tma_load_b; - } - } - - // - // Methods - // - - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { - (void) workspace; - - // from the flat problem shape arrays of ConvProblemShape, create a rank-3 MNK problem shape tuple - // tma desc creation depends on the original untransformed domain. - - // A extents. - auto shape_A_orig = problem_shape.get_shape_A(); - // B extents. - auto shape_B_orig = problem_shape.get_shape_B(); - - // Fill inferred cute strides from flat stride arrays - auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp); - auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp); - - auto ptr_A = recast_ptr(args.ptr_A); - auto ptr_B = recast_ptr(args.ptr_B); - - Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA)); - Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB)); - - auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); - // Cluster layout for TMA construction - auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); - auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); - - // Cluster layout for TMA construction - auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); - - auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape, cluster_layout_vmnk); - auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape, cluster_layout_vmnk); - auto tma_load_a_fallback = get_tma_load_a_instance(tensor_a, problem_shape, cluster_layout_vmnk_fallback); - auto tma_load_b_fallback = get_tma_load_b_instance(tensor_b, problem_shape, cluster_layout_vmnk_fallback); - - static_assert(size(typename decltype(tma_load_a)::ThrID{}) == size(AtomThrShapeMNK{})); - static_assert(size(typename decltype(tma_load_b)::ThrID{}) == size(AtomThrShapeMNK{})); - - return { - tma_load_a, - tma_load_b, - tma_load_a_fallback, - tma_load_b_fallback, - hw_info.cluster_shape_fallback - }; - } - - template - static bool - can_implement( - ProblemShape const& problem_shape, - Arguments const& args) { - // Activation and Filter channel mode extents much match - bool implementable = true; - // channel mode is major - { - const bool check = problem_shape.stride_A[NumTensorDimensions-1] == 1; -#if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0) - if (not check) { - const auto offending_stride = - problem_shape.stride_A[NumTensorDimensions-1]; - std::ostringstream os; - os << "CollectiveConv::can_implement: " - "problem_shape.stride_A[NumTensorDimensions-1 = " - << (NumTensorDimensions-1) << "] = " - << offending_stride << " != 1"; - CUTLASS_TRACE_HOST( os.str() ); - } -#endif - implementable &= check; - } - - { - const bool check = problem_shape.stride_B[NumTensorDimensions-1] == 1; -#if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0) - if (not check) { - const auto offending_stride = - problem_shape.stride_B[NumTensorDimensions-1]; - std::ostringstream os; - os << "CollectiveConv::can_implement: " - "problem_shape.stride_B[NumTensorDimensions-1 = " - << (NumTensorDimensions-1) << "] = " - << offending_stride << " != 1\n"; - CUTLASS_TRACE_HOST( os.str() ); - } -#endif - implementable &= check; - } - - { - const auto & traversal_stride = problem_shape.traversal_stride; - for (auto stride: traversal_stride) { - implementable &= (stride >= 1 && stride <= 8); - } - } - - if constexpr (ConvOp == conv::Operator::kDgrad && not is_strided_dgrad) { - const auto & traversal_stride = problem_shape.traversal_stride; - for (auto stride: traversal_stride) { - implementable &= (stride == 1); - } - } - - constexpr int tma_alignment_bits = 128; - // A extents. - auto shape_A_orig = problem_shape.get_shape_A(); - // B extents. - auto shape_B_orig = problem_shape.get_shape_B(); - - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - { - const bool check = cutlass::detail::check_alignment(shape_A_orig, StrideA{}); - if (not check) { - CUTLASS_TRACE_HOST("A shape and/or strides have alignment issue."); - } - implementable &= check; - } - - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - { - const bool check = cutlass::detail::check_alignment(shape_B_orig, StrideB{}); - if (not check) { - CUTLASS_TRACE_HOST("B shape and/or strides have alignment issue."); - } - implementable &= check; - } - - if (not implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - return false; - } - - if (is_im2col_A || is_im2col_B) { - // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1] - constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1); - auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); - for (int i = 0; i < problem_shape.RankS; ++i) { - implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1); - } - auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); - for (int i = 0; i < problem_shape.RankS; ++i) { - implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1); - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); - return false; - } - } - - if (is_im2col_A || is_im2col_B) { - // Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit] - constexpr int32_t offset_limit = (1 << (16 / NumSpatialDimensions)) - 1; - auto flt_data = (ConvOp == conv::Operator::kWgrad) ? problem_shape.shape_C : problem_shape.shape_B; - for (int i = 0; i < problem_shape.RankS; ++i) { - // flt_data array contains [K, T, R, S, C], so pure filter [T, R, S] starts from the second position in the array - implementable = implementable && ((flt_data[i+1] - 1) * problem_shape.dilation[i] >= 0) - && ((flt_data[i+1] - 1) * problem_shape.dilation[i] <= offset_limit); - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: tensor coordinate offset values don't meet requirements for TMA LOAD IM2COL.\n"); - return false; - } - } - - // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized) - if constexpr (ConvOp == conv::Operator::kWgrad) { - - const auto & input_shape = problem_shape.shape_A; - const auto & input_stride = problem_shape.stride_A; - - implementable &= input_stride[ProblemShape::RankT - 1] == 1; - int64_t input_shape_size = 1; - for (int i = ProblemShape::RankT - 2; i >= 0; --i) { - input_shape_size *= input_shape[i + 1]; - implementable &= input_stride[i] == input_shape_size; - } - - const auto & output_shape = problem_shape.shape_C; - const auto & output_stride = problem_shape.stride_C; - - implementable &= output_stride[ProblemShape::RankT - 1] == 1; - int64_t output_shape_size = 1; - for (int i = ProblemShape::RankT - 2; i >= 0; --i) { - output_shape_size *= output_shape[i + 1]; - implementable &= output_stride[i] == output_shape_size; - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); - return false; - } - } - - // Conv kernels only support cross correlation mode currently. - { - implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation; - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n"); - return false; - } - } - - // When groups > 1, it should be a Grouped Conv. - if (problem_shape.groups > 1) { - implementable &= TileShapeMNKLRank > 3; - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Only Grouped Conv can support groups > 1.\n"); - return false; - } - } - - // Only support Grouped Wgrad currently. - if constexpr (TileShapeMNKLRank > 3) { - implementable &= ConvOp == conv::Operator::kWgrad; - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Conv Only support Grouped Wgrad currently.\n"); - return false; - } - } - - // Grouped Wgrad channel check. - if constexpr (is_grouped_wgrad) { - - int input_K = size<0>(problem_shape.get_shape_A()); - int input_C = size<0>(problem_shape.get_shape_B()); - - implementable &= input_K == input_C; - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Conv's input K and input C do not match.\n"); - return false; - } - - int output_K = size<0>(problem_shape.get_shape_C()); - int output_C = size<1,0>(problem_shape.get_shape_C()); - - implementable &= input_K == output_K; - implementable &= input_C == output_C * problem_shape.groups; - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Wgrad's input and output K,C and groups do not match\n"); - return false; - } - - constexpr int Tile_N = size<1>(TileShape{}); - constexpr int GroupsPerTile = size<3>(TileShapeMNKL_{}); - - implementable &= Tile_N / GroupsPerTile == input_C / problem_shape.groups; - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Wgrad's Tile_N, GroupsPerTile and input_C, groups do not match.\n"); - return false; - } - } - - // The extents of linearized problem shape should be int32_t type(maximum is 2^31-1). - if constexpr (is_im2col_A || is_im2col_B) { - auto [M, N, K, L] = cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); - auto to_64b = [](auto S) { return transform_leaf(S, [](auto s) { return static_cast(s); }); }; - - if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { - implementable &= (cute::product(to_64b(M)) <= cutlass::platform::numeric_limits::max()) & - (cute::product(to_64b(L)) <= cutlass::platform::numeric_limits::max()); - } - else if constexpr (ConvOp == conv::Operator::kWgrad) { - implementable &= (cute::product(to_64b(K)) <= cutlass::platform::numeric_limits::max()); - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: the extents exceed the maximum number.\n"); - return false; - } - } - - return true; - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE void - prefetch_tma_descriptors() { - cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); - cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); - } - - /// Construct A Single Stage's Accumulator Shape - CUTLASS_DEVICE static auto - partition_accumulator_shape() { - auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) - - return acc_shape; - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Producer Perspective - template < - class GTensorA, class GTensorB, - class GTensorPartitionedA, class GTensorPartitionedB, - class STensorA, class STensorB, - class TileCoordMNKL, - class KTileIterator - > - CUTLASS_DEVICE auto - load( - Params const& params, - MainloopPipeline pipeline, - MainloopPipelineState mainloop_pipe_producer_state, - cute::tuple const& load_inputs, - TileCoordMNKL const& cta_coord_mnkl, - KTileIterator k_tile_iter, int k_tile_count) { - - auto [unused_gA, unused_gB, - tAgA_mk, tBgB_nk, tAsA, tBsB, - mcast_mask_a, mcast_mask_b] = load_inputs; - - // slice out the work coord from partitioned tensors - Tensor tAgA = tAgA_mk(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _); - auto tensor_b_coord = get<1>(cta_coord_mnkl); - if constexpr (is_grouped_wgrad) { - // in grouped wgrad, tensor A = NZPQK, tensor B = NDHWC, tensor C = KTRSc, where C = G*c, c = channel_per_group = 8,16,32. - // CTA Tiling follows output tensor KTRSc. So cta_size_m = K/CTA_TILE_M. cta_size_n = T*R*S*ceil(c/CTA_TILE_N) = T*R*S*1 = T*R*S. - // tensor_a_coord = K_idx = cta_coord_m. - // tensor_b_coord = TRS_idx * C/CTA_TILE_N + C_idx = cta_coord_n * get<1,0>(shape(tBgB_nk) + cta_coord_m, - // because K == C and CTA_TILE_M == CTA_TILE_N => C_idx = K_idx = cta_coord_m. - tensor_b_coord = get<0>(cta_coord_mnkl) + get<1>(cta_coord_mnkl) * get<1,0>(shape(tBgB_nk)); - } - Tensor tBgB = tBgB_nk(_, tensor_b_coord, _); - - auto barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); - - // Issue the Mainloop loads - CUTLASS_PRAGMA_NO_UNROLL - while (k_tile_count > 0) { - // LOCK mainloop_pipe_producer_state for _writing_ - pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); - - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_pipe_producer_state); - - int write_stage = mainloop_pipe_producer_state.index(); - ++mainloop_pipe_producer_state; - barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); - - if constexpr (is_strided_dgrad) { - // construct gemm-k tile coord for gB - auto [conv_k, flt_coord, out_coord] = *k_tile_iter; - auto gemm_k_tile = prepend(flt_coord, conv_k); // (k,s,r,t) - - // gA doesn't have a gemm-k (k,s,r,t) iterator mode because it's not an im2col tensor - auto offset_kqpzn = append(prepend(out_coord, _0{}),_0{}); // (k,q,p,z,n) - auto tAgA_offset = make_tensor(tAgA.data() + offset_kqpzn, tAgA.layout()); // (TMA, k) - - if (cute::elect_one_sync()) { - copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA_offset(_,conv_k), tAsA(_,write_stage)); - copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,gemm_k_tile) , tBsB(_,write_stage)); - } - } - else { - if (cute::elect_one_sync()) { - copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); - copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); - } - } - - --k_tile_count; - ++k_tile_iter; - } - - return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); - } - - /// Set up the data needed by this collective for load. - /// Return tuple element contain - /// gA_mk - The tiled tma tensor for input A - /// gB_nk - The tiled tma tensor for input B - /// tAsA - partitioned smem tensor for A - /// tBsB - partitioned smem tensor for B - /// mcast_mask_a - tma multicast mask for A - /// mcast_mask_b - tma multicast mask for B - template - CUTLASS_DEVICE auto - load_init( - ProblemShape_MNKL const& problem_shape_MNKL, - Params const& params, - TensorStorage& shared_tensors) const { - using X = Underscore; - - // Separate out problem shape for convenience - auto [M,N,K,L] = problem_shape_MNKL; - - // Represent the full tensors -- get these from TMA - auto K_A = conditional_return(get<0>(K), K); - Tensor mA_mk = observed_tma_load_a_->get_tma_tensor(make_shape(M, K_A)); - Tensor mB_nk = observed_tma_load_b_->get_tma_tensor(make_shape(N, K)); - - // Tile the tensors and defer the slice - Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k) - Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k) - - // Partition for this CTA - ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); - - Tensor tCgA_mk = cta_mma.partition_A(gA_mk); // (MMA, MMA_M, MMA_K, m, k) - Tensor tCgB_nk = cta_mma.partition_B(gB_nk); // (MMA, MMA_N, MMA_K, n, k) - - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) - - // Define the CTA-in-cluster Layout and Coord - Layout cta_layout_mnk = make_layout(cluster_shape_); - Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); - auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); - - // Project the cta_layout for tma_a along the n-modes - auto [tAgA_mk, tAsA] = tma_partition(*observed_tma_load_a_, - get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), - group_modes<0,3>(sA), group_modes<0,3>(tCgA_mk)); - - // Project the cta_layout for tma_b along the m-modes - auto [tBgB_nk, tBsB] = tma_partition(*observed_tma_load_b_, - get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), - group_modes<0,3>(sB), group_modes<0,3>(tCgB_nk)); - - // TMA Multicast Masks - uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); - uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); - - return cute::make_tuple( - gA_mk, gB_nk, // for scheduler - tAgA_mk, tBgB_nk, tAsA, tBsB, // for input tensor values - mcast_mask_a, mcast_mask_b); // multicast masks - } - - /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster - CUTLASS_DEVICE void - load_tail(MainloopPipeline pipeline, MainloopPipelineState mainloop_pipe_producer_state) { - // Issue the epilogue waits - /* This helps avoid early exit of ctas in Cluster - * Waits for all stages to either be released (all - * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was - * still inverted from make_producer_start_state - */ - pipeline.producer_tail(mainloop_pipe_producer_state); - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Consumer Perspective - template < - class FrgEngine, class FrgLayout, - class FragmentA, class FragmentB - > - CUTLASS_DEVICE auto - mma(MainloopPipeline pipeline, - MainloopPipelineState mainloop_pipe_consumer_state, - cute::Tensor& accumulators, - cute::tuple const& mma_inputs, - int k_tile_count) - { - static_assert(is_tmem::value, "Accumulator must be tmem resident."); - static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); - - auto [tiled_mma, tCrA, tCrB] = mma_inputs; - - uint32_t skip_wait = k_tile_count <= 0; - auto barrier_token = pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - - // - // PIPELINED MAIN LOOP - // - tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; - - CUTLASS_PRAGMA_NO_UNROLL - while (k_tile_count > 0) { - // WAIT on mainloop_pipe_consumer_state until its data are available (phase bit flips from mainloop_pipe_consumer_state.phase() value) - pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); - - // Compute on k_tile - int read_stage = mainloop_pipe_consumer_state.index(); - // Save current mainlop pipeline read state - auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; - - // Advance mainloop_pipe - ++mainloop_pipe_consumer_state; - --k_tile_count; - skip_wait = k_tile_count <= 0; - // Peek at next iteration - barrier_token = pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - - // Unroll the K mode manually so we can set scale C to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulators); - tiled_mma.accumulate_ = UMMA::ScaleOut::One; - } - pipeline.consumer_release(curr_mainloop_pipe_consumer_state); - } - - return mainloop_pipe_consumer_state; - } - - CUTLASS_DEVICE auto - mma_init(TensorStorage& shared_tensors) const { - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - - TiledMma tiled_mma; - - // Allocate "fragments/descriptors" for A and B matrices - Tensor tCrA = tiled_mma.make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = tiled_mma.make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - - CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE - return cute::make_tuple(tiled_mma, tCrA, tCrB); - } - -private: - - typename Params::TMA_A const* observed_tma_load_a_ = nullptr; - typename Params::TMA_B const* observed_tma_load_b_ = nullptr; - - ClusterShape cluster_shape_; - uint32_t block_rank_in_cluster_; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp deleted file mode 100644 index 11eefed94182c8d8870a65c9f4d937ede5db5421..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +++ /dev/null @@ -1,785 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" - -#include "cute/arch/cluster_sm90.hpp" -#include "cute/arch/copy_sm90.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cute/atom/copy_traits_sm90_im2col.hpp" -#include "cute/numeric/arithmetic_tuple.hpp" -#include "cute/algorithm/functional.hpp" -#include "cute/algorithm/gemm.hpp" - -#include "cutlass/conv/detail.hpp" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/dispatch_policy.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/util/packed_stride.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv::collective { -using namespace cute; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - conv::Operator ConvOp, - int Stages, - int NumSpatialDims, - class ClusterShape, - class KernelSchedule, - int PipelineAsyncMmaStages, - class TileShape_, - class ElementA_, - class ElementB_, - class TiledMma_, - class TileTraitsA_, - class TileTraitsB_> -struct CollectiveConv< - MainloopSm90TmaGmmaWarpSpecializedImplicitGemm< - ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>, - TileShape_, - ElementA_, - ElementB_, - TiledMma_, - TileTraitsA_, - TileTraitsB_> -{ - // - // Type Aliases - // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedImplicitGemm< - ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>; - using TileShape = TileShape_; - using ElementA = ElementA_; - using ElementB = ElementB_; - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = typename TileTraitsA_::GmemTiledCopy; - using GmemTiledCopyB = typename TileTraitsB_::GmemTiledCopy; - using SmemLayoutA = typename TileTraitsA_::SmemLayout; - using SmemLayoutB = typename TileTraitsB_::SmemLayout; - using ArchTag = typename DispatchPolicy::ArchTag; - static constexpr int NumSpatialDimensions = DispatchPolicy::NumSpatialDimensions; - static constexpr int NumTensorDimensions = NumSpatialDimensions + 2; - // Deduce the kernel-facing stride tuple types based on the dispatch policy - // (which is a function of the number of spatial dimensions, the algorithm, etc.) - using StrideA = decltype(detail::sm90_dispatch_policy_to_stride_A()); - using StrideB = decltype(detail::sm90_dispatch_policy_to_stride_B()); - - using MainloopPipeline = cutlass::PipelineTmaAsync; - - using PipelineParams = typename MainloopPipeline::Params; - using PipelineState = typename cutlass::PipelineState; - - using ProblemShape = ConvProblemShape; - - static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); - static_assert((size<0>(TileShape{}) == size<0>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape."); - static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape."); - - static_assert(rank(SmemLayoutB{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); - static_assert((size<1>(TileShape{}) == size<0>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape."); - static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape."); - - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(cute::is_base_of::value && - cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - - // The tma load mode of wgrad is tiled for tensor A and im2col for tensor B while the tma load mode of fprop and dgrad - // kernel is im2col for tensor A and tiled for tensor B. - static_assert((ConvOp == conv::Operator::kWgrad - && (cute::is_same_v || cute::is_same_v)) - || (ConvOp != conv::Operator::kWgrad - && (cute::is_same_v || cute::is_same_v)), - "GmemTiledCopyA - invalid SM90 TMA copy atom specified."); - static_assert((ConvOp == conv::Operator::kWgrad - && (cute::is_same_v || cute::is_same_v)) - || (ConvOp != conv::Operator::kWgrad - && (cute::is_same_v || cute::is_same_v)), - "GmemTiledCopyB - invalid SM90 TMA copy atom specified."); - - static constexpr bool is_im2col_A = detail::is_im2col_load::value; - static constexpr bool is_im2col_B = detail::is_im2col_load::value; - - // TMA converts f32 input to tf32 when copying from GMEM to SMEM - // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. - static constexpr bool ConvertF32toTF32A = cute::is_same_v; - static constexpr bool ConvertF32toTF32B = cute::is_same_v; - using InternalElementA = cute::conditional_t>>; - using InternalElementB = cute::conditional_t>>; - - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128, _0> { - cute::array_aligned> smem_A; - cute::array_aligned> smem_B; - } tensors; - - using PipelineStorage = typename MainloopPipeline::SharedStorage; - PipelineStorage pipeline; - }; - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; - static constexpr uint32_t TmaTransactionBytes = - (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(InternalElementA)))+ - (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(InternalElementB))); - - // Host side kernel arguments - struct Arguments { - ElementA const* ptr_A{nullptr}; - ElementB const* ptr_B{nullptr}; - }; - -private: - // Note that for fprop and dgrad kernel, the tma load mode is im2col for tensor A and tiled for - // tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor - // B since operand A, B is swapped. - // Get tma_load_a instantce. - template - static constexpr auto - get_tma_load_a_instance(TensorA const& tensor_a, ProblemShape const& problem_shape) { - if constexpr (is_im2col_A) { - // compute the upper and lower corners based on the conv padding - auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); - auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); - auto lower_srt = detail::compute_lower_srt(problem_shape); - - // The calculation of gbasis strides for dgrad kernel needs perform negate for dilation values. - cute::array stride_srt{}; - for (int i = 0; i < NumSpatialDimensions; ++i) { - stride_srt[i] = ConvOp == conv::Operator::kDgrad ? - -problem_shape.dilation[NumSpatialDimensions-1-i] : - problem_shape.dilation[NumSpatialDimensions-1-i]; - } - - return make_im2col_tma_copy( - GmemTiledCopyA{}, - tensor_a, - SmemLayoutA{}(_,_,_0{}), - product_each(shape(SmemLayoutA{}(_,_,_0{}))), - size<1>(ClusterShape{}), - shape(lower_corner_whd), - shape(upper_corner_whd), - cute::reverse(shape(problem_shape.lower_padding)), - cute::reverse(shape(problem_shape.upper_padding)), - cute::reverse(shape(problem_shape.traversal_stride)), - shape(lower_srt), - shape(stride_srt)); - } - // TMA tiled mode for tensor A in wgrad kernel. - else { - return make_tma_copy( - GmemTiledCopyA{}, - tensor_a, - SmemLayoutA{}(_,_,_0{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); - } - } - - // Get tma_load_b instantce. - template - static constexpr auto - get_tma_load_b_instance(TensorB const& tensor_b, ProblemShape const& problem_shape) { - // TMA im2col mode for tensor B in wgrad kernel. - if constexpr (is_im2col_B) { - // compute the upper and lower corners based on the conv padding - auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); - auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); - auto lower_srt = detail::compute_lower_srt(problem_shape); - - return make_im2col_tma_copy( - GmemTiledCopyB{}, - tensor_b, - SmemLayoutB{}(_,_,_0{}), - product_each(shape(SmemLayoutB{}(_,_,_0{}))), - size<0>(ClusterShape{}), - shape(lower_corner_whd), - shape(upper_corner_whd), - cute::reverse(shape(problem_shape.lower_padding)), - cute::reverse(shape(problem_shape.upper_padding)), - cute::reverse(shape(problem_shape.traversal_stride)), - shape(lower_srt), - cute::reverse(shape(problem_shape.dilation))); - } - else { - return make_tma_copy( - GmemTiledCopyB{}, - tensor_b, - SmemLayoutB{}(_,_,_0{}), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); - } - } - -public: - - // Performs im2col transformations on the input of type ConvProblemShape - static constexpr auto - get_problem_shape_MNKL(ProblemShape const& problem_shape) { - - if constexpr (is_im2col_A || is_im2col_B) { - // transformation + im2col linearization - return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape); - } - else { - // transformation - return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); - } - } - - // Device side kernel params - struct Params { - using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{})); - - // Assumption: StrideA is congruent with Problem_MK - // Select TMA load type according to convolution operator. - using TensorShapeA = cute::conditional_t; - - using TensorShapeB = cute::conditional_t; - - using TMA_A = decltype(get_tma_load_a_instance( - make_tensor( - make_gmem_ptr(static_cast(nullptr)), - make_layout(TensorShapeA{}, StrideA{})), - ConvProblemShape{})); - - using TMA_B = decltype(get_tma_load_b_instance( - make_tensor( - make_gmem_ptr(static_cast(nullptr)), - make_layout(TensorShapeB{}, StrideB{})), - ConvProblemShape{})); - - // Members - TMA_A tma_load_a; - TMA_B tma_load_b; - uint32_t tma_transaction_bytes = TmaTransactionBytes; - }; - - // - // Methods - // - - // Lowers the host side user facing arguments to the kernel facing lauch params - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - (void) workspace; - // from the flat problem shape arrays of ConvProblemShape, create a rank-3 MNK problem shape tuple - // tma desc creation depends on the original untransformed domain. - - // A extents. - auto shape_A_orig = problem_shape.get_shape_A(); - // B extents. - auto shape_B_orig = problem_shape.get_shape_B(); - - // Fill inferred cute strides from flat stride arrays - auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp); - auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp); - - auto ptr_A = reinterpret_cast(args.ptr_A); - auto ptr_B = reinterpret_cast(args.ptr_B); - - Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA)); - Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB)); - - auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape); - auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape); - - return { - tma_load_a, - tma_load_b, - TmaTransactionBytes - }; - } - - template - static bool - can_implement( - ProblemShape const& problem_shape, - Arguments const& args) { - // Activation and Filter channel mode extents much match - bool implementable = true; - // channel mode is major - implementable &= problem_shape.stride_A[NumTensorDimensions-1] == 1; - implementable &= problem_shape.stride_B[NumTensorDimensions-1] == 1; - - constexpr int tma_alignment_bits = 128; - // A extents. - auto shape_A_orig = problem_shape.get_shape_A(); - // B extents. - auto shape_B_orig = problem_shape.get_shape_B(); - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(shape_A_orig, StrideA{}); - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(shape_B_orig, StrideB{}); - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - return false; - } - - // Check valid padding values for TMA_LOAD_IM2COL - constexpr int padding_limit = (ProblemShape::RankS == 1) ? 65536 : (ProblemShape::RankS == 2 ? 256 : 16); - for (int i = 0; i < problem_shape.RankS; ++i) { - implementable = implementable && problem_shape.lower_padding[i] <= padding_limit && problem_shape.lower_padding[i] >= 0; - implementable = implementable && problem_shape.upper_padding[i] <= padding_limit && problem_shape.upper_padding[i] >= 0; - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); - return false; - } - - if (is_im2col_A || is_im2col_B) { - // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1] - constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1); - auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); - for (int i = 0; i < problem_shape.RankS; ++i) { - implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1); - } - auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); - for (int i = 0; i < problem_shape.RankS; ++i) { - implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1); - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); - return false; - } - } - - if (is_im2col_A || is_im2col_B) { - // Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit - 1] - constexpr int32_t offset_limit = (1 << (16 / NumSpatialDimensions)) - 1; - auto flt_data = (ConvOp == conv::Operator::kWgrad) ? problem_shape.shape_C : problem_shape.shape_B; - for (int i = 0; i < problem_shape.RankS; ++i) { - // flt_data array contains [K, T, R, S, C], so pure filter [T, R, S] starts from the second position in the array - implementable = implementable && ((flt_data[i+1] - 1) * problem_shape.dilation[i] >= 0) - && ((flt_data[i+1] - 1) * problem_shape.dilation[i] < offset_limit); - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: tensor coordinate offset values don't meet requirements for TMA LOAD IM2COL.\n"); - return false; - } - } - - // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized) - if constexpr (ConvOp == conv::Operator::kWgrad) { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::ostringstream os; -#endif - const auto & input_shape = problem_shape.shape_A; - const auto & input_stride = problem_shape.stride_A; - - implementable &= input_stride[ProblemShape::RankT - 1] == 1; - int64_t input_shape_size = 1; - for (int i = ProblemShape::RankT - 2; i >= 0; --i) { - input_shape_size *= input_shape[i + 1]; - implementable &= input_stride[i] == input_shape_size; -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - if (input_stride[i] != input_shape_size) { - os << "\n *** input_stride[" << i << "] = " << input_stride[i] << " != input_shape_size = " << input_shape_size << " ***"; - } -#endif - } - - if (!implementable) { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - os << "\n input_shape_size: " << input_shape_size - << "\n input_shape: " << input_shape - << "\n input_stride: " << input_stride - << "\n"; -#endif - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed input strides.\n"); -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST(os.str()); -#endif - return false; - } - - const auto & output_shape = problem_shape.shape_C; - const auto & output_stride = problem_shape.stride_C; - - implementable &= output_stride[ProblemShape::RankT - 1] == 1; - int64_t output_shape_size = 1; - for (int i = ProblemShape::RankT - 2; i >= 0; --i) { - output_shape_size *= output_shape[i + 1]; - implementable &= output_stride[i] == output_shape_size; -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - if (output_stride[i] != output_shape_size) { - os << "\n *** output_stride[" << i << "] = " << output_stride[i] << " != output_shape_size = " << output_shape_size << " ***"; - } -#endif - } - - if (!implementable) { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - os << "\n output_shape_size: " << input_shape_size - << "\n output_shape: " << input_shape - << "\n output_stride: " << input_stride - << "\n"; -#endif - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST(os.str()); -#endif - return false; - } - } - - // Conv kernels only support cross correlation mode currently. - implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation; - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n"); - return false; - } - - if (problem_shape.groups > 1) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n"); - return false; - } - - if constexpr (is_im2col_A || is_im2col_B) { - auto [M, N, K, L] = cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); - auto to_64b = [](auto S) { return transform_leaf(S, [](auto s) { return static_cast(s); }); }; - - if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { - implementable &= (cute::product(to_64b(M)) <= cutlass::platform::numeric_limits::max()) & - (cute::product(to_64b(L)) <= cutlass::platform::numeric_limits::max()); - } - else if constexpr (ConvOp == conv::Operator::kWgrad) { - implementable &= (cute::product(to_64b(K)) <= cutlass::platform::numeric_limits::max()); - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: the extents exceed the maximum number.\n"); - return false; - } - } - - return true; - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - } - - /// Set up the data needed by this collective for load and mma. - /// Returns a tuple of tensors. The collective and the kernel layer have the contract - /// Returned tuple must contain at least two elements, with the first two elements being: - /// gA_mk - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k) - /// gB_nk - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k) - /// The rest of the tensors can be specified as needed by this collective. - /// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with - /// StrideA and StrideB set up for TMA - template - CUTLASS_DEVICE auto - load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params){ - //load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params) const { - using X = Underscore; - // Separate out problem shape for convenience - auto [M, N, K, L] = problem_shape_MNKL; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mk = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K)); // (m,k) - Tensor mB_nk = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K)); // (n,k) - - // Make tiled views, defer the slice - Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k) - Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) - - return cute::make_tuple(gA_mk, gB_nk); - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Producer Perspective - template < - class TensorA, class TensorB, - class KTileIterator, class BlockCoord - > - CUTLASS_DEVICE void - load( - Params const& mainloop_params, - MainloopPipeline pipeline, - PipelineState smem_pipe_producer_state, - cute::tuple const& load_inputs, - BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, - uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) { - - int lane_predicate = cute::elect_one_sync(); - if (lane_predicate) { - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - - // - // Prepare the TMA loads for A and B - // - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); - - auto [gA_mk, gB_nk] = load_inputs; - - // Partition the inputs based on the current block coordinates. - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - - Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k) - Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k) - - // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) - - uint16_t mcast_mask_a = 0; - uint16_t mcast_mask_b = 0; - - // Issue TmaLoads - // Maps the tile -> block, value - if constexpr (cute::is_same_v || - cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); - } - } - - if constexpr (cute::is_same_v || - cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); - } - } - - // Mainloop - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { - // LOCK smem_pipe_producer_state for _writing_ - pipeline.producer_acquire(smem_pipe_producer_state); - - // - // Copy gmem to smem for *k_tile_iter - // - - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_producer_state); - - int write_stage = smem_pipe_producer_state.index(); - - copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); - ++k_tile_iter; - - // Advance smem_pipe_producer_state - ++smem_pipe_producer_state; - } - } - } - - /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void - load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_producer_state) { - int lane_predicate = cute::elect_one_sync(); - - // Issue the epilogue waits - if (lane_predicate) { - /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all - * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was - * still inverted from make_producer_start_state - */ - pipeline.producer_tail(smem_pipe_producer_state); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Consumer Perspective - template - CUTLASS_DEVICE void - mma(MainloopPipeline pipeline, - PipelineState smem_pipe_consumer_state, - FrgTensorC& accum, - int k_tile_count, - int thread_idx, - TensorStorage& shared_tensors, - Params const& mainloop_params) { - static_assert(is_rmem::value, "C tensor must be rmem resident."); - - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - - // - // Define C accumulators and A/B partitioning - // - - TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - - // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE - - // - // PIPELINED MAIN LOOP - // - static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), - "ERROR : Incorrect number of MMAs in flight"); - - // We release buffers to producer warps(dma load) with some mmas in flight - PipelineState smem_pipe_release = smem_pipe_consumer_state; - - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - - warpgroup_fence_operand(accum); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) { - // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value) - pipeline.consumer_wait(smem_pipe_consumer_state); - - int read_stage = smem_pipe_consumer_state.index(); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - - warpgroup_commit_batch(); - - ++smem_pipe_consumer_state; - } - - warpgroup_fence_operand(accum); - // Mainloop GMMAs - k_tile_count -= prologue_mma_count; - - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { - // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value) - pipeline.consumer_wait(smem_pipe_consumer_state); - - // - // Compute on k_tile - // - - int read_stage = smem_pipe_consumer_state.index(); - warpgroup_fence_operand(accum); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_producer_state is consumed - warpgroup_wait(); - warpgroup_fence_operand(accum); - - // UNLOCK smem_pipe_release, done _computing_ on it - pipeline.consumer_release(smem_pipe_release); - - // Advance smem_pipe_consumer_state and smem_pipe_release - ++smem_pipe_consumer_state; - ++smem_pipe_release; - } - - warpgroup_fence_operand(accum); - } - - /// Perform a Consumer Epilogue to release all buffers - CUTLASS_DEVICE void - mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - k_tile_count -= prologue_mma_count; - - smem_pipe_release.advance(k_tile_count); - - // Wait on all GMMAs to complete - warpgroup_wait<0>(); - - for (int count = 0; count < prologue_mma_count; ++count) { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - ++smem_pipe_release; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv2d_problem_size.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv2d_problem_size.h deleted file mode 100644 index fbef858a54eda2ffbfea30e8ff9bd570bcf841f9..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv2d_problem_size.h +++ /dev/null @@ -1,658 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief This file contains definitions and utility functions for describing convolution problem sizes. - - Conv2dProblem desciption: - activation (NHWC), - filter (KRSC), - output (NPQK), - pading (pad_h, pad_w), - stride (stride_h, stride_w), - dilation (dilation_h, dilation_w). - - Free functions to map: - Map tensor extents (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator) - Map tensor sizes (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) - Map tensor problem sizes (Conv2d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm_enumerated_types.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/functional.h" - -namespace cutlass { -namespace conv { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Problem size structure -struct Conv2dProblemSize { - - // Conv2d strictly problem size parameters - int N, H, W, C, P, Q, K, R, S; - int pad_h, pad_w; - int stride_h, stride_w; - int dilation_h, dilation_w; - Mode mode; - - // Conv2d implementation-related parameters - int split_k_slices; - int groups; - - // - // Methods - // - -public: - CUTLASS_HOST_DEVICE - Conv2dProblemSize(): - N(0), H(0), W(0), C(0), P(0), Q(0), K(0), R(0), S(0), - pad_h(0), pad_w(0), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), - mode(Mode::kConvolution), split_k_slices(1), groups(1) { } - - /// Constructor for default padding, stride, dilation, and split-K - CUTLASS_HOST_DEVICE - Conv2dProblemSize( - int N, - int H, - int W, - int C, - int P, - int Q, - int K, - int R, - int S, - Mode mode - ): - N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S), - pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), - mode(mode), split_k_slices(1), groups (1) { } - - /// Constructor - CUTLASS_HOST_DEVICE - Conv2dProblemSize( - int N, - int H, - int W, - int C, - int K, - int R, - int S, - int P, - int Q, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - Mode mode, - int split_k_slices = 1, - int groups = 1 - ): - N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S), - pad_h(pad_h), pad_w(pad_w), stride_h(stride_h), stride_w(stride_w), - dilation_h(dilation_h), dilation_w(dilation_w), - mode(mode), split_k_slices(split_k_slices), groups (groups) { } - - /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord - // set user-defined output size and sets P and Q (include all data members in ctor) - CUTLASS_HOST_DEVICE - Conv2dProblemSize( - cutlass::Tensor4DCoord input_size, // NHWC - cutlass::Tensor4DCoord filter_size, // KRSC - cutlass::Tensor4DCoord padding, // pad_h, _, pad_w, _ - cutlass::MatrixCoord stride, // stride_h, stride_w - cutlass::MatrixCoord dilation, // dilation_h, dilation_w - cutlass::Tensor4DCoord output_size, // NPQK - cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, - int split_k_slices = 1, - int groups = 1 - ): - N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), - P(output_size.h()), Q(output_size.w()), - K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), - pad_h(padding[0]), pad_w(padding[2]), - stride_h(stride.row()), stride_w(stride.column()), - dilation_h(dilation.row()), dilation_w(dilation.column()), - mode(mode), split_k_slices(split_k_slices), groups(groups) {} - - /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord - // computes output size and sets P and Q (skip output from ctor arguments) - CUTLASS_HOST_DEVICE - Conv2dProblemSize( - cutlass::Tensor4DCoord input_size, // NHWC - cutlass::Tensor4DCoord filter_size, // KRSC - cutlass::Tensor4DCoord padding, // pad_h, upper_pad_h, pad_w, upper_pad_w - cutlass::MatrixCoord stride, // stride_h, stride_w - cutlass::MatrixCoord dilation, // dilation_h, dilation_w - cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, - int split_k_slices = 1, - int groups = 1 - ): - N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), - K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), - pad_h(padding[0]), pad_w(padding[2]), - stride_h(stride.row()), stride_w(stride.column()), - dilation_h(dilation.row()), dilation_w(dilation.column()), - mode(mode), split_k_slices(split_k_slices), groups(groups) { - // set output P and Q - P = ((H + pad_h + padding[1] - R * dilation_h) / stride_h) + 1; - Q = ((W + pad_w + padding[3] - S * dilation_w) / stride_w) + 1; - } - - /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord - // set user-defined output size and sets P and Q (skip padding, striding, and dilation) - CUTLASS_HOST_DEVICE - Conv2dProblemSize( - cutlass::Tensor4DCoord input_size, // NHWC - cutlass::Tensor4DCoord filter_size, // KRSC - cutlass::Tensor4DCoord output_size, // NPQK - cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, - int split_k_slices = 1, - int groups = 1 - ): - N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), - P(output_size.h()), Q(output_size.w()), - K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), - pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), - dilation_h(1), dilation_w(1), - mode(mode), split_k_slices(split_k_slices), groups(groups) {} - - // Reset covolution mode in the problem - CUTLASS_HOST_DEVICE - Conv2dProblemSize reset_mode(cutlass::conv::Mode mode_) { - Conv2dProblemSize tmp(*this); - tmp.mode = mode_; - return tmp; - } - - // Reset covolution mode in the problem - CUTLASS_HOST_DEVICE - Conv2dProblemSize reset_split_k_slices(int split_k_slices_) { - Conv2dProblemSize tmp(*this); - tmp.split_k_slices = split_k_slices_; - return tmp; - } - - /// Equality operator (ignores mode and split_k_slice) - CUTLASS_HOST_DEVICE - bool operator==(Conv2dProblemSize const &conv) const { - return ( - (N == conv.N) && (H == conv.H) && (W == conv.W) && (C == conv.C) && - (K == conv.K) && (R == conv.R) && (S == conv.S) && - (P == conv.P) && (Q == conv.Q) && - (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && - (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && - (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) - ); - } - - /// Inequality operator - CUTLASS_HOST_DEVICE - bool operator!=(Conv2dProblemSize const &rhs) const { - return !(*this == rhs); - } - - /// Returns activation extent as Tensor4DCoord - CUTLASS_HOST_DEVICE - cutlass::Tensor4DCoord activation_extent() const { - - return cutlass::Tensor4DCoord ({N, H, W, C}); - } - - /// Returns filter extent as Tensor4DCoord - CUTLASS_HOST_DEVICE - cutlass::Tensor4DCoord filter_extent(bool is_deconv = false) const { - - return is_deconv ? cutlass::Tensor4DCoord ({C, R, S, K / groups}) - : cutlass::Tensor4DCoord ({K, R, S, C / groups}); - } - - /// Returns output extent as Tensor4DCoord - CUTLASS_HOST_DEVICE - cutlass::Tensor4DCoord output_extent() const { - - return cutlass::Tensor4DCoord ({N, P, Q, K}); - } - - /// Returns activation size in number of elements - CUTLASS_HOST_DEVICE - int64_t activation_size() const { - - return static_cast(N) * static_cast(H) * - static_cast(W) * static_cast(C); - } - - /// Returns filter size in number of elements - CUTLASS_HOST_DEVICE - int64_t filter_size() const { - - return static_cast(K) * static_cast(R) * - static_cast(S) * static_cast(C) / - static_cast(groups); - } - - /// Returns output size in number of elements - CUTLASS_HOST_DEVICE - int64_t output_size() const { - - return static_cast(N) * static_cast(P) * - static_cast(Q) * static_cast(K); - } - - /// Returns padding as Tensor4DCoord - CUTLASS_HOST_DEVICE - cutlass::Tensor4DCoord padding() const { - - return cutlass::Tensor4DCoord ({pad_h, pad_h, pad_w, pad_w}); - } - - /// Returns stride as MatrixCoord - CUTLASS_HOST_DEVICE - cutlass::MatrixCoord stride() const { - - return cutlass::MatrixCoord ({stride_h, stride_w}); - } - - /// Returns dilation as MatrixCoord - CUTLASS_HOST_DEVICE - cutlass::MatrixCoord dilation() const { - - return cutlass::MatrixCoord ({dilation_h, dilation_w}); - } - - ///////////////////////////////////////////////////////////////// - // Methods used for strided dgrad implementation - ///////////////////////////////////////////////////////////////// - /// Number of filter r positions to accumulate in gemm-k dim - CUTLASS_HOST_DEVICE - int num_gemm_k_filter_r(int r) const { - return ((R - r + stride_h - 1) / stride_h); - } - - /// Number of filter s positions to accumulate in gemm-k dim - CUTLASS_HOST_DEVICE - int num_gemm_k_filter_s(int s) const { - return ((S - s + stride_w - 1) / stride_w); - } - - /// Number of filter positions to accumulate in gemm-k dim - CUTLASS_HOST_DEVICE - int num_gemm_k_filter_positions(int r, int s) const { - return num_gemm_k_filter_r(r) * num_gemm_k_filter_s(s); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// ImplicitGemm helper functions // -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Determine the problem size of the implicit GEMM operation -CUTLASS_HOST_DEVICE -cutlass::gemm::GemmCoord implicit_gemm_problem_size( - Operator conv_operator, - Conv2dProblemSize const &problem_size) { - // Compute problem size - switch (conv_operator) { - case Operator::kFprop: - return gemm::GemmCoord( - problem_size.N * problem_size.P * problem_size.Q, - problem_size.K, - problem_size.R * problem_size.S * problem_size.C / problem_size.groups - ); - case Operator::kDeconv: - case Operator::kDgrad: - return gemm::GemmCoord( - problem_size.N * problem_size.H * problem_size.W, - problem_size.C, - problem_size.R * problem_size.S * problem_size.K - ); - case Operator::kWgrad: - return gemm::GemmCoord( - problem_size.K, - problem_size.R * problem_size.S * problem_size.C, - problem_size.N * problem_size.P * problem_size.Q - ); - default: - break; - } - return gemm::GemmCoord(); -} - -// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm -CUTLASS_HOST_DEVICE -int implicit_gemm_k_iterations( - Operator conv_operator, - int threadblock_K, - Conv2dProblemSize const &problem_size, - IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, - GroupMode group_mode = GroupMode::kNone, - int threadblock_N = 0) { - - int iterations = 0; - - if (group_mode == GroupMode::kNone) { - - if (algorithm == IteratorAlgorithm::kFixedChannels) { - - int positions_per_iteration = threadblock_K / problem_size.C; - switch (conv_operator) { - case Operator::kFprop: - iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration; - break; - - default: - break; - } - } - else if (algorithm == IteratorAlgorithm::kFewChannels) { - - switch (conv_operator) { - case Operator::kFprop: - iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K; - break; - - default: - break; - } - } - else { - int elements_per_split_k_slice = 0; - - switch (conv_operator) { - case Operator::kFprop: - elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; - - case Operator::kDeconv: - case Operator::kDgrad: - elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; - - case Operator::kWgrad: - elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; - break; - - default: - break; - } - } - - } else if (group_mode == GroupMode::kDepthwise) { - int channels_per_cta = threadblock_N; - - if (algorithm == IteratorAlgorithm::kAnalytic) { - switch (conv_operator) { - case Operator::kFprop: - iterations = problem_size.R * problem_size.S * - ((channels_per_cta + threadblock_K - 1) / threadblock_K); - break; - - default: - break; - } - } - } else { // Group conv - - int channels_per_group = problem_size.C / problem_size.groups; - int k_per_group = problem_size.K / problem_size.groups; - - if (algorithm == IteratorAlgorithm::kAnalytic) { - switch (conv_operator) { - case Operator::kFprop: - iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); - // In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups - if (problem_size.groups != 1) { - if (k_per_group < threadblock_N) { - iterations *= threadblock_N / k_per_group; - } - } - break; - - default: - break; - } - } else if (algorithm == IteratorAlgorithm::kOptimized) { - // Current optimized iterator only support GroupMode::kSingleGroup - if (group_mode == GroupMode::kSingleGroup) { - switch (conv_operator) { - case Operator::kFprop: - iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); - break; - - default: - break; - } - } - } - - } - - return iterations; -} - - -template -CUTLASS_HOST_DEVICE -int depthwise_gemm_k_iterations( - Operator conv_operator, - int threadblock_K, - Conv2dProblemSize const &problem_size, - IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, - GroupMode group_mode = GroupMode::kNone, - int threadblock_N = 0) { - - int n = problem_size.N; - int p = (problem_size.P + Output_P - 1) / Output_P; - int q = (problem_size.Q + Output_Q - 1) / Output_Q; - - int iterations = (n * p * q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - return iterations; -} - - -CUTLASS_HOST_DEVICE -int implicit_gemm_k_iterations_per_channel( - Operator conv_operator, - Conv2dProblemSize const &problem_size, - IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { - - int iterations = 0; //0 means not applicable - if (algorithm == IteratorAlgorithm::kAnalytic || algorithm == IteratorAlgorithm::kOptimized) { - switch (conv_operator) { - case Operator::kFprop: - iterations = problem_size.R * problem_size.S; - break; - - case Operator::kDeconv: - case Operator::kDgrad: - iterations = problem_size.R * problem_size.S; - break; - - default: - break; - } - } - return iterations; -} - -//////////////////////////////////////////////////////////////////////////////// -// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) -//////////////////////////////////////////////////////////////////////////////// -/// Returns ImplicitGemm tensor A extent as Tensor4DCoord -CUTLASS_HOST_DEVICE -cutlass::Tensor4DCoord implicit_gemm_tensor_a_extent( - Operator conv_operator, - Conv2dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); - case cutlass::conv::Operator::kDeconv: - case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); - case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); - default : break; - } - return cutlass::Tensor4DCoord(); -} - -/// Returns ImplicitGemm tensor B extent as Tensor4DCoord -CUTLASS_HOST_DEVICE -cutlass::Tensor4DCoord implicit_gemm_tensor_b_extent( - Operator conv_operator, - Conv2dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); - case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true); - case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); - case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); - default : break; - } - return cutlass::Tensor4DCoord(); -} - -/// Returns ImplicitGemm tensor C extent as Tensor4DCoord -CUTLASS_HOST_DEVICE -cutlass::Tensor4DCoord implicit_gemm_tensor_c_extent( - Operator conv_operator, - Conv2dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); - case cutlass::conv::Operator::kDeconv: - case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); - case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); - default : break; - } - return cutlass::Tensor4DCoord(); -} - -/// Returns ImplicitGemm tensor A size in number of elements -CUTLASS_HOST_DEVICE -int64_t implicit_gemm_tensor_a_size( - Operator conv_operator, - Conv2dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); - case cutlass::conv::Operator::kDeconv: - case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); - case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); - default : break; - } - return 0; -} - -/// Returns ImplicitGemm tensor B size in number of elements -CUTLASS_HOST_DEVICE -int64_t implicit_gemm_tensor_b_size( - Operator conv_operator, - Conv2dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); - case cutlass::conv::Operator::kDeconv: - case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); - case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); - default : break; - } - return 0; -} - -/// Returns ImplicitGemm tensor C size in number of elements -CUTLASS_HOST_DEVICE -int64_t implicit_gemm_tensor_c_size( - Operator conv_operator, - Conv2dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.output_size(); - case cutlass::conv::Operator::kDeconv: - case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); - case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); - default : break; - } - return 0; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Strided dgrad helper functions // -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Returns number of CTAs tile M to cover valid MMAs per starting filter postion -CUTLASS_HOST_DEVICE -int strided_dgrad_tile_m_per_filter( - Conv2dProblemSize const &problem_size, - int tile_size_m) { - - // Compute NHW rows in Dx output that needs MMA per starting filter position - int rows_h_per_filter = (problem_size.H + problem_size.stride_h - 1) / problem_size.stride_h; - int rows_w_per_filter = (problem_size.W + problem_size.stride_w - 1) / problem_size.stride_w; - int rows_nhw_per_filter = problem_size.N * rows_h_per_filter * rows_w_per_filter; - - // Number of CTAs tile M to cover valid MMAs per starting filter postion - int tile_m_per_filter = (rows_nhw_per_filter + tile_size_m - 1) / tile_size_m; - - return tile_m_per_filter; -} - -// Computes starting Dx coord (h, w) for given starting filter postion -CUTLASS_HOST_DEVICE -void strided_dgrad_starting_coords( - Conv2dProblemSize const &problem_size, - FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, - int r, int s, - int &start_h, int &start_w) { - - // function locals for remainder by fast divmod - int pad_h_rem_, pad_w_rem_; - - // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h; - stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h); - int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r)); - stride_h_divmod.divmod(start_h, r_); - - //start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w; - stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w); - int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s)); - stride_w_divmod.divmod(start_w, s_); -} - -} // namespace conv -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv3d_problem_size.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv3d_problem_size.h deleted file mode 100644 index 48bf056e17014400a6bc41b87193a05de3cb9c96..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv3d_problem_size.h +++ /dev/null @@ -1,519 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief This file contains definitions and utility functions for describing convolution problem sizes. - - Conv3dProblem desciption: - activation (NDHWC), - filter (KTRSC), - output (NZPQK), - pading (pad_d, pad_h, pad_w), - stride (stride_d, stride_h, stride_w), - dilation (dilation_d, dilation_h, dilation_w). - - Free functions to map: - Map tensor extents (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator) - Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) - Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) -*/ - -#pragma once - -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" - -namespace cutlass { -namespace conv { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Problem size structure -struct Conv3dProblemSize : public Conv2dProblemSize { - // - // Type definitions - // - - // 3D coordinate for padding, stride, and dilation in (d, h, w) dimensions - using Coord3D = Coord<3>; - - // - // Data members - // - - // Conv3d strictly problem size parameters - int D, T, Z; // input depth, filter depth, output depth - int pad_d; // padding in depth dimension - int stride_d; // stride in depth dimension - int dilation_d; // dilation in depth dimension - - // - // Methods - // -public: - CUTLASS_HOST_DEVICE - Conv3dProblemSize(): - Conv2dProblemSize(), - D(0), T(0), Z(0), - pad_d(0), - stride_d(1), - dilation_d(1) { } - - /// Constructor for default padding, stride, dilation, and split-K - CUTLASS_HOST_DEVICE - Conv3dProblemSize( - int N, - int D, - int H, - int W, - int C, - int Z, - int P, - int Q, - int K, - int T, - int R, - int S, - Mode mode - ): - Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode), - D(D), T(T), Z(Z), - pad_d(T / 2), stride_d(1), dilation_d(1) { } - - /// Constructor - CUTLASS_HOST_DEVICE - Conv3dProblemSize( - int N, - int D, - int H, - int W, - int C, - int K, - int T, - int R, - int S, - int Z, - int P, - int Q, - int pad_d, - int pad_h, - int pad_w, - int stride_d, - int stride_h, - int stride_w, - int dilation_d, - int dilation_h, - int dilation_w, - Mode mode, - int split_k_slices = 1, - int groups = 1 - ): - Conv2dProblemSize( - N, H, W, C, K, R, S, P, Q, - pad_h, pad_w, - stride_h, stride_w, - dilation_h, dilation_w, - mode, split_k_slices, groups), - D(D), T(T), Z(Z), - pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d) { } - - /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D - // set *user-defined* output size and sets Z, P, and Q (include all data members in ctor) - CUTLASS_HOST_DEVICE - Conv3dProblemSize( - cutlass::Tensor5DCoord input_size, // NDHWC - cutlass::Tensor5DCoord filter_size, // KTRSC - Coord3D padding, // pad_d, pad_h, pad_w - Coord3D stride, // stride_d, stride_h, stride_w - Coord3D dilation, // dilation_d, dilation_h, dilation_w - cutlass::Tensor5DCoord output_size, // NZPQK - cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, - int split_k_slices = 1, - int groups = 1 - ): - Conv2dProblemSize( - {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, - {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, - {padding[1], padding[1], padding[2], padding[2]}, - {stride[1], stride[2]}, - {dilation[1], dilation[2]}, - {output_size.n(), output_size.h(), output_size.w(), output_size.c()}, - mode, split_k_slices, groups), - D(input_size.d()), T(filter_size.d()), Z(output_size.d()), - pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) { } - - /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D - // *computes* output size and sets Z, P and Q (include all data members in ctor) - CUTLASS_HOST_DEVICE - Conv3dProblemSize( - cutlass::Tensor5DCoord input_size, // NDHWC - cutlass::Tensor5DCoord filter_size, // KTRSC - Coord3D padding, // pad_d, pad_h, pad_w - Coord3D stride, // stride_d, stride_h, stride_w - Coord3D dilation, // dilation_d, dilation_h, dilation_w - cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, - int split_k_slices = 1, - int groups = 1 - ): - Conv2dProblemSize( - {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, - {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, - {padding[1], padding[1], padding[2], padding[2]}, - {stride[1], stride[2]}, - {dilation[1], dilation[2]}, - mode, split_k_slices, groups), - D(input_size.d()), T(filter_size.d()), - pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) - { - // set output Z - Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1; - } - - /// Constructs convolution problem size from cutlass Tensor5DCoord, Coord3D - // *computes* output size and sets Z, P and Q (include all data members in ctor) - CUTLASS_HOST_DEVICE - Conv3dProblemSize( - cutlass::Tensor5DCoord input_size, // NDHWC - cutlass::Tensor5DCoord filter_size, // KTRSC - CUTLASS_STL_NAMESPACE::tuple padding, // Coord3D {pad_d, pad_h, pad_w} & Coord3D {far pad_d, pad_h, pad_w} to calculate o/p/q - Coord3D stride, // stride_d, stride_h, stride_w - Coord3D dilation, // dilation_d, dilation_h, dilation_w - cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, - int split_k_slices = 1, - int groups = 1 - ): - Conv2dProblemSize( - {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, - {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, - {CUTLASS_STL_NAMESPACE::get<0>(padding)[1], CUTLASS_STL_NAMESPACE::get<1>(padding)[1], - CUTLASS_STL_NAMESPACE::get<0>(padding)[2], CUTLASS_STL_NAMESPACE::get<1>(padding)[2]}, - {stride[1], stride[2]}, - {dilation[1], dilation[2]}, - mode, split_k_slices, groups), - D(input_size.d()), T(filter_size.d()), - pad_d(CUTLASS_STL_NAMESPACE::get<0>(padding)[0]), stride_d(stride[0]), dilation_d(dilation[0]) - { - // set output Z - Z = ((D + pad_d + CUTLASS_STL_NAMESPACE::get<1>(padding)[0] - T * dilation_d) / stride_d) + 1; - } - - /// Equality operator (ignores mode and split_k_slice) - CUTLASS_HOST_DEVICE - bool operator==(Conv3dProblemSize const &conv) const { - return ( - (N == conv.N) && (D == conv.D) && (H == conv.H) && (W == conv.W) && (C == conv.C) && - (K == conv.K) && (T == conv.T) && (R == conv.R) && (S == conv.S) && - (Z == conv.Z) &&(P == conv.P) && (Q == conv.Q) && - (pad_d == conv.pad_d) && (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && - (stride_d == conv.stride_d) && (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && - (dilation_d == conv.dilation_d) && (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) - ); - } - - /// Inequality operator - CUTLASS_HOST_DEVICE - bool operator!=(Conv3dProblemSize const &rhs) const { - return !(*this == rhs); - } - - // Reset covolution mode in the problem - CUTLASS_HOST_DEVICE - Conv3dProblemSize reset_mode(cutlass::conv::Mode mode_) { - Conv3dProblemSize tmp(*this); - tmp.mode = mode_; - return tmp; - } - - // Reset covolution mode in the problem - CUTLASS_HOST_DEVICE - Conv3dProblemSize reset_split_k_slices(int split_k_slices_) { - Conv3dProblemSize tmp(*this); - tmp.split_k_slices = split_k_slices_; - return tmp; - } - - /// Returns activation extent as Tensor5DCoord - CUTLASS_HOST_DEVICE - cutlass::Tensor5DCoord activation_extent() const { - - return cutlass::Tensor5DCoord ({N, D, H, W, C}); - } - - /// Returns filter extent as Tensor5DCoord - CUTLASS_HOST_DEVICE - cutlass::Tensor5DCoord filter_extent(bool is_deconv = false) const { - - return is_deconv ? cutlass::Tensor5DCoord ({C, T, R, S, K}) - : cutlass::Tensor5DCoord ({K, T, R, S, C}); - } - - /// Returns output extent as Tensor5DCoord - CUTLASS_HOST_DEVICE - cutlass::Tensor5DCoord output_extent() const { - - return cutlass::Tensor5DCoord ({N, Z, P, Q, K}); - } - - /// Returns activation size in number of elements - CUTLASS_HOST_DEVICE - int64_t activation_size() const { - - return static_cast(N) * static_cast(D) * - static_cast(H) * static_cast(W) * - static_cast(C); - } - - /// Returns filter size in number of elements - CUTLASS_HOST_DEVICE - int64_t filter_size() const { - - return static_cast(K) * static_cast(T) * - static_cast(R) * static_cast(S) * - static_cast(C); - } - - /// Returns output size in number of elements - CUTLASS_HOST_DEVICE - int64_t output_size() const { - - return static_cast(N) * static_cast(Z) * - static_cast(P) * static_cast(Q) * - static_cast(K); - } - - /// Returns padding as Coord3D - CUTLASS_HOST_DEVICE - Coord3D padding() const { - - return Coord3D ({pad_d, pad_h, pad_w}); - } - - /// Returns stride as MatrixCoord - CUTLASS_HOST_DEVICE - Coord3D stride() const { - - return Coord3D ({stride_d, stride_h, stride_w}); - } - - /// Returns dilation as MatrixCoord - CUTLASS_HOST_DEVICE - Coord3D dilation() const { - - return Coord3D ({dilation_d, dilation_h, dilation_w}); - } - -}; - - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// ImplicitGemm helper functions // -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Determine the problem size of the implicit GEMM operation -CUTLASS_HOST_DEVICE -cutlass::gemm::GemmCoord implicit_gemm_problem_size( - Operator conv_operator, - Conv3dProblemSize const &problem_size) { - // Compute problem size - switch (conv_operator) { - case Operator::kFprop: - return gemm::GemmCoord( - problem_size.N * problem_size.Z * problem_size.P * problem_size.Q, - problem_size.K, - problem_size.T * problem_size.R * problem_size.S * problem_size.C - ); - case Operator::kDeconv: - case Operator::kDgrad: - return gemm::GemmCoord( - problem_size.N * problem_size.D * problem_size.H * problem_size.W, - problem_size.C, - problem_size.T * problem_size.R * problem_size.S * problem_size.K - ); - case Operator::kWgrad: - return gemm::GemmCoord( - problem_size.K, - problem_size.T * problem_size.R * problem_size.S * problem_size.C, - problem_size.N * problem_size.Z * problem_size.P * problem_size.Q - ); - default: - break; - } - return gemm::GemmCoord(); -} - -// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm -CUTLASS_HOST_DEVICE -int implicit_gemm_k_iterations( - Operator conv_operator, - int threadblock_K, - Conv3dProblemSize const &problem_size, - IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, - GroupMode group_mode = GroupMode::kNone, - int threadblock_N = 0) { - - int iterations = 0; - int elements_per_split_k_slice = 0; - if (group_mode == GroupMode::kNone) { - switch (conv_operator) { - case Operator::kFprop: - elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; - - case Operator::kDeconv: - case Operator::kDgrad: - elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; - - case Operator::kWgrad: - elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; - break; - - default: - break; - } - } else if (group_mode == GroupMode::kDepthwise) { - int channels_per_cta = threadblock_N; - - if (algorithm == IteratorAlgorithm::kAnalytic) { - switch (conv_operator) { - case Operator::kFprop: - iterations = problem_size.T * problem_size.R * problem_size.S * - ((channels_per_cta + threadblock_K - 1) / threadblock_K); - break; - - default: - break; - } - } - } - - return iterations; -} - -//////////////////////////////////////////////////////////////////////////////// -// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) -//////////////////////////////////////////////////////////////////////////////// -/// Returns ImplicitGemm tensor A extent as Tensor5DCoord -CUTLASS_HOST_DEVICE -cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent( - Operator conv_operator, - Conv3dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); - case cutlass::conv::Operator::kDeconv: - case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); - case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); - default : break; - } - return cutlass::Tensor5DCoord(); -} - -/// Returns ImplicitGemm tensor B extent as Tensor5DCoord -CUTLASS_HOST_DEVICE -cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent( - Operator conv_operator, - Conv3dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); - case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true); - case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); - case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); - default : break; - } - return cutlass::Tensor5DCoord(); -} - -/// Returns ImplicitGemm tensor C extent as Tensor5DCoord -CUTLASS_HOST_DEVICE -cutlass::Tensor5DCoord implicit_gemm_tensor_c_extent( - Operator conv_operator, - Conv3dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); - case cutlass::conv::Operator::kDeconv: - case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); - case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); - default : break; - } - return cutlass::Tensor5DCoord(); -} - -/// Returns ImplicitGemm tensor A size in number of elements -CUTLASS_HOST_DEVICE -int64_t implicit_gemm_tensor_a_size( - Operator conv_operator, - Conv3dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); - case cutlass::conv::Operator::kDeconv: - case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); - case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); - default : break; - } - return 0; -} - -/// Returns ImplicitGemm tensor B size in number of elements -CUTLASS_HOST_DEVICE -int64_t implicit_gemm_tensor_b_size( - Operator conv_operator, - Conv3dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); - case cutlass::conv::Operator::kDeconv: - case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); - case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); - default : break; - } - return 0; -} - -/// Returns ImplicitGemm tensor C size in number of elements -CUTLASS_HOST_DEVICE -int64_t implicit_gemm_tensor_c_size( - Operator conv_operator, - Conv3dProblemSize const &problem_size) { - switch (conv_operator) { - case cutlass::conv::Operator::kFprop: return problem_size.output_size(); - case cutlass::conv::Operator::kDeconv: - case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); - case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); - default : break; - } - return 0; -} - -} // namespace conv -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convnd_problem_shape.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convnd_problem_shape.hpp deleted file mode 100644 index 3c31c21b2508914d10d41bb865a6da145bf3c106..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convnd_problem_shape.hpp +++ /dev/null @@ -1,601 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief This file contains definitions and utility functions for describing convolution problem shapes. -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/conv/convolution.h" - -#include "cute/container/array.hpp" - -#if ! defined(__CUDACC_RTC__) -#include -#endif - - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Implements the user facing argument for all CUTLASS 3.x convolutions in a rank agnostic fashion. -// All tensors are flat and by default treated as layout right (NDHWC, KTRSC, NZPQK) -// Supports asymmetric padding, traversal strides, dilations, and all conv algorithm types. -template < - conv::Operator ConvOp_, - int NumSpatialDimensions_ -> -struct ConvProblemShape { - // - // Alias types for members - // - - static constexpr int RankS = NumSpatialDimensions_; - static constexpr int RankT = NumSpatialDimensions_ + 2; - static constexpr conv::Operator ConvOp = ConvOp_; - static constexpr int NumSpatialDimensions = NumSpatialDimensions_; - using SpatialExtent = cute::array; - using TensorExtent = cute::array; - using TensorStride = cute::array; - using ShapePadding = SpatialExtent; - using TraversalStride = SpatialExtent; - using ShapeDilation = SpatialExtent; - using Corner = SpatialExtent; - - // - // Members - // - cutlass::conv::Mode mode{}; - TensorExtent shape_A{}; - TensorStride stride_A{}; - TensorExtent shape_B{}; - TensorStride stride_B{}; - TensorExtent shape_C{}; - TensorStride stride_C{}; - - // asymmetric padding, both upper and lower padding must be >= 0 - ShapePadding lower_padding{}; - ShapePadding upper_padding{}; - TraversalStride traversal_stride{}; - ShapeDilation dilation{}; - int groups = 1; - - // - // Methods - // - - ConvProblemShape() = default; - - // Constructor accepts user facing arguments and computes to stores the corners as its internal state - ConvProblemShape( - conv::Mode mode, // convolution/cross-correlation - TensorExtent shape_act, // [n,d,h,w,c] - TensorStride stride_act, // [n,d,h,w,c] - TensorExtent shape_flt, // [k,t,r,s,c] - TensorStride stride_flt, // [k,t,r,s,c] - ShapePadding lower_padding, // [pad_d, pad_h, pad_w] - ShapePadding upper_padding, // [pad_d, pad_h, pad_w] - TraversalStride tstride, // [stride_d, stride_h, stride_w] - ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w] - int groups) - : mode(mode) - , lower_padding(lower_padding) - , upper_padding(upper_padding) - , traversal_stride(tstride) - , dilation(dilation) - , groups(groups) { - - auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); - set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); - } - - // Allow user input of xformed activation stride to support non-packed strides. - ConvProblemShape( - conv::Mode mode, // convolution/cross-correlation - TensorExtent shape_act, // [n,d,h,w,c] - TensorStride stride_act, // [n,d,h,w,c] - TensorExtent shape_flt, // [k,t,r,s,c] - TensorStride stride_flt, // [k,t,r,s,c] - TensorStride stride_xformed_act, // [n,z,p,q,k] - ShapePadding lower_padding, // [pad_d, pad_h, pad_w] - ShapePadding upper_padding, // [pad_d, pad_h, pad_w] - TraversalStride tstride, // [stride_d, stride_h, stride_w] - ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w] - int groups) - : mode(mode) - , lower_padding(lower_padding) - , upper_padding(upper_padding) - , traversal_stride(tstride) - , dilation(dilation) - , groups(groups) { - - CUTLASS_ASSERT(stride_act[RankT - 1] == 1); - CUTLASS_ASSERT(stride_flt[RankT - 1] == 1); - CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1); - - auto stride_act_packed = packed_stride_right_major(shape_act); - auto stride_flt_packed = packed_stride_right_major(shape_flt); - auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt); - - CUTLASS_PRAGMA_UNROLL - for(int i = 0; i < RankT - 1; ++i) { - CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]); - CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]); - CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]); - } - - set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); - } - - // Constructor accepts user facing arguments and presume packed tensor strides in canonical (CWHDN) order. - ConvProblemShape( - conv::Mode mode, - TensorExtent shape_act, - TensorExtent shape_flt, - ShapePadding lower_padding, - ShapePadding upper_padding, - TraversalStride tstride, - ShapeDilation dilation, - int groups) - : ConvProblemShape( - mode, - shape_act, - packed_stride_right_major(shape_act), - shape_flt, - packed_stride_right_major(shape_flt), - lower_padding, - upper_padding, - tstride, - dilation, - groups) { - } - -#if ! defined(__CUDACC_RTC__) - // Constructor accepts user facing arguments and computes to stores the corners as its internal state - ConvProblemShape( - conv::Mode mode, - std::initializer_list shape_act_, - std::initializer_list stride_act_, - std::initializer_list shape_flt_, - std::initializer_list stride_flt_, - std::initializer_list lower_padding_, - std::initializer_list upper_padding_, - std::initializer_list traversal_stride_, - std::initializer_list dilation_, - int groups) - : mode(mode) - , groups(groups) { - - TensorExtent shape_act{}; - TensorStride stride_act{}; - TensorExtent shape_flt{}; - TensorStride stride_flt{}; - - assert(shape_act_.size() == shape_act.size()); - assert(stride_act_.size() == stride_act.size()); - assert(shape_flt_.size() == shape_flt.size()); - assert(stride_flt_.size() == stride_flt.size()); - assert(lower_padding_.size() == lower_padding.size()); - assert(upper_padding_.size() == upper_padding.size()); - assert(traversal_stride_.size() == traversal_stride.size()); - assert(dilation_.size() == dilation.size()); - - std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); - std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin()); - std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); - std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin()); - std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); - std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); - std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); - std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); - - auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); - set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); - } - - // Allow user input of xformed activation stride to support non-packed strides. - ConvProblemShape( - conv::Mode mode, - std::initializer_list shape_act_, - std::initializer_list stride_act_, - std::initializer_list shape_flt_, - std::initializer_list stride_flt_, - std::initializer_list stride_xformed_act_, - std::initializer_list lower_padding_, - std::initializer_list upper_padding_, - std::initializer_list traversal_stride_, - std::initializer_list dilation_, - int groups) - : mode(mode) - , groups(groups) { - TensorExtent shape_act{}; - TensorStride stride_act{}; - TensorExtent shape_flt{}; - TensorStride stride_flt{}; - TensorStride stride_xformed_act{}; - - std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); - std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin()); - std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); - std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin()); - std::copy(stride_xformed_act_.begin(), stride_xformed_act_.end(), stride_xformed_act.begin()); - std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); - std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); - std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); - std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); - - CUTLASS_ASSERT(stride_act[RankT - 1] == 1); - CUTLASS_ASSERT(stride_flt[RankT - 1] == 1); - CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1); - - auto stride_act_packed = packed_stride_right_major(shape_act); - auto stride_flt_packed = packed_stride_right_major(shape_flt); - auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt); - - CUTLASS_PRAGMA_UNROLL - for(int i = 0; i < RankT - 1; ++i) { - CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]); - CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]); - CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]); - } - - set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); - } - - // Constructor accepts user facing arguments and computes to stores the corners as its internal state - ConvProblemShape( - conv::Mode mode, - std::initializer_list shape_act_, - std::initializer_list shape_flt_, - std::initializer_list lower_padding_, - std::initializer_list upper_padding_, - std::initializer_list traversal_stride_, - std::initializer_list dilation_, - int groups) - : mode(mode) - , groups(groups) { - TensorExtent shape_act{}; - TensorStride stride_act{}; - TensorExtent shape_flt{}; - TensorStride stride_flt{}; - - assert(shape_act_.size() == shape_act.size()); - assert(shape_flt_.size() == shape_flt.size()); - assert(lower_padding_.size() == lower_padding.size()); - assert(upper_padding_.size() == upper_padding.size()); - assert(traversal_stride_.size() == traversal_stride.size()); - assert(dilation_.size() == dilation.size()); - - std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); - std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); - std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); - std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); - std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); - std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); - stride_act = packed_stride_right_major(shape_act); - stride_flt = packed_stride_right_major(shape_flt); - - auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); - set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); - } -#endif // not defined(__CUDACC_RTC__) - - // Set shape and stride of tensor A/B/C according to following table: - // | | Fprop | Dgrad | Wgrad | - // | ------ | ------ | ------ | ------| - // | ShapeA | NDHWC | NZPQK | NZPQK | - // | ShapeB | KTRSC | KTRSC | NDHWC | - // | ShapeC | NZPQK | NDHWC | KTRSC | - // - // Input comes from calculate_xformed_act, which does NOT depend on ConvOp. - CUTLASS_HOST_DEVICE - constexpr void - set_shape_stride_ABC( - TensorExtent shape_act, - TensorStride stride_act, - TensorExtent shape_flt, - TensorStride stride_flt, - TensorExtent shape_xformed_act, - TensorStride stride_xformed_act) { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - printf("*** set_shape_stride_ABC ***"); - printf("\n shape_act: "); - print(shape_act); - printf("\n stride_act: "); - print(stride_act); - printf("\n shape_flt: "); - print(shape_flt); - printf("\n stride_flt: "); - print(stride_flt); - printf("\n shape_xformed_act: "); - print(shape_xformed_act); - printf("\n stride_xformed_act: "); - print(stride_xformed_act); - if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { - printf("\n ConvOp: Fprop"); - } - if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { - printf("\n ConvOp: Dgrad"); - } - if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { - printf("\n ConvOp: Wgrad"); - } - printf("\n"); -#endif - - if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { - shape_A = shape_act; - stride_A = stride_act; - shape_B = shape_flt; - stride_B = stride_flt; - shape_C = shape_xformed_act; - stride_C = stride_xformed_act; - } - else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { - shape_A = shape_xformed_act; - stride_A = stride_xformed_act; - shape_B = shape_flt; - stride_B = stride_flt; - shape_C = shape_act; - stride_C = stride_act; - } - else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { - shape_A = shape_xformed_act; - stride_A = stride_xformed_act; - shape_B = shape_act; - stride_B = stride_act; - shape_C = shape_flt; - stride_C = stride_flt; - } -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - printf("\n shape_A: "); - print(shape_A); - printf("\n stride_A: "); - print(stride_A); - printf("\n shape_B: "); - print(shape_B); - printf("\n stride_B: "); - print(stride_B); - printf("\n shape_C: "); - print(shape_C); - printf("\n stride_C: "); - print(stride_C); -#endif - } - - // Get A extents. - // fprop: A extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) - // dgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) - // wgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((K), (Q,P,Z,N)) - CUTLASS_HOST_DEVICE - constexpr auto - get_shape_A() const { - using cute::make_shape; - using cute::take; - - if constexpr (ConvOp == conv::Operator::kFprop || - ConvOp == conv::Operator::kDgrad) { - return make_shape( - cute::reverse(take<0, RankT - 1>(shape_A)), - shape_A[RankT - 1]); - } - // For wgrad kernel, we need to linearize NZPQ for tensor A - else if constexpr (ConvOp == conv::Operator::kWgrad) { - return make_shape( - shape_A[RankT - 1], - cute::product(take<0, RankT - 1>(shape_A))); - } - } - - // Get B extents. - // fprop: B extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T)) - // dgrad: B extents array contains [K,T,R,S,C]. Turn that into ((C), (K,S,R,T)) - // wgrad: B extents array contains [N,D,H,W,C]. Turn that into ((C), (W,H,D,N)) - CUTLASS_HOST_DEVICE - constexpr auto - get_shape_B() const { - using cute::make_shape; - using cute::reverse; - using cute::take; - - if constexpr (ConvOp == conv::Operator::kFprop) { - return make_shape( - shape_B[0], - reverse(take<1, RankT>(shape_B))); - } - else if constexpr (ConvOp == conv::Operator::kWgrad) { - return make_shape( - shape_B[RankT - 1], - reverse(take<0, RankT - 1>(shape_B))); - } - else if constexpr (ConvOp == conv::Operator::kDgrad) { - // shape_B: [K,T,R,S,C], return: [(C),(K,S,R,T)] - return make_shape( - shape_B[RankT - 1], - cute::insert<0>( - reverse(take<1, RankT - 1>(shape_B)), - shape_B[0])); - } - } - - // Get C extents. - // fprop: C extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) - // dgrad: C extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) - // wgrad: C extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T)) - CUTLASS_HOST_DEVICE - constexpr auto - get_shape_C() const { - using cute::make_shape; - using cute::reverse; - using cute::take; - - if constexpr (ConvOp == conv::Operator::kFprop || - ConvOp == conv::Operator::kDgrad) { - return make_shape( - reverse(take<0, RankT - 1>(shape_C)), - shape_C[RankT - 1]); - } - else if constexpr (ConvOp == conv::Operator::kWgrad) { - return make_shape( - shape_C[0], - reverse(take<1, RankT>(shape_C))); - } - } - - // Static method that returns the canonical strides of tensors (layouts are right major and compact) - CUTLASS_HOST_DEVICE - static constexpr TensorStride - packed_stride_right_major(TensorExtent const& extents) { - TensorStride strides{}; - strides[RankT-1] = 1; - cute::for_each(cute::make_rseq{}, [&](auto i) { - strides[i] = extents[i+1] * strides[i+1]; - }); - return strides; - } - - // Static method that returns the packed logical size of any TensorExtent - CUTLASS_HOST_DEVICE - static constexpr size_t - size(TensorExtent const& extents) { - size_t size = 1; - cute::for_each(cute::make_seq{}, [&](auto i) { - size *= extents[i]; - }); - return size; - } - - CUTLASS_HOST_DEVICE - constexpr size_t - size_A() const { - return shape_A[0] * stride_A[0]; - } - - CUTLASS_HOST_DEVICE - constexpr size_t - size_B() const { - return shape_B[0] * stride_B[0]; - } - - CUTLASS_HOST_DEVICE - constexpr size_t - size_C() const { - return shape_C[0] * stride_C[0]; - } - - // Equality operator - CUTLASS_HOST_DEVICE - bool operator==(ConvProblemShape const& rhs) const { - using cute::for_each; - using cute::make_seq; - - bool is_equal = true; - - // Compare all tensor extents - for_each(make_seq{}, [&](auto i) { - is_equal = is_equal - && (shape_A[i] == rhs.shape_A[i]) - && (shape_B[i] == rhs.shape_B[i]); - }); - - // Compare all spatial extents - for_each(make_seq{}, [&](auto i) { - is_equal = is_equal - && (lower_padding[i] == rhs.lower_padding[i]) - && (upper_padding[i] == rhs.upper_padding[i]) - && (traversal_stride[i] == rhs.traversal_stride[i]) - && (dilation[i] == rhs.dilation[i]); - }); - - return is_equal; - } - - /// Inequality operator - CUTLASS_HOST_DEVICE - bool operator!=(ConvProblemShape const &rhs) const { - return !(*this == rhs); - } - -private: - CUTLASS_HOST_DEVICE - constexpr auto - calculate_xformed_act(TensorExtent shape_act, TensorExtent shape_flt) { - TensorExtent shape_xformed_act{}; - // calculate n,z,p,q,k. - // a helper lambda to compute a single spatial extent of the nzpqk tensor - auto nzpqk_extent = [](int act_ext, int filter_ext, int pad_total, int dilation, int tstride) { - return 1 + (act_ext + pad_total - ((filter_ext -1) * dilation + 1)) / tstride; - }; - - shape_xformed_act[0] = shape_act[0]; // Activation N extent - cute::for_each(cute::make_seq{}, [&](auto i) { - shape_xformed_act[i+1] = nzpqk_extent( - shape_act[i+1], shape_flt[i+1], upper_padding[i] + lower_padding[i], dilation[i], traversal_stride[i]); - }); - shape_xformed_act[RankT-1] = shape_flt[0]; // Filter K extent - - TensorStride stride_xformed_act = packed_stride_right_major(shape_xformed_act); - - return cute::make_tuple(shape_xformed_act, stride_xformed_act); - } -}; - -template< - conv::Operator ConvOp, - int SpatialDim -> -void print(ConvProblemShape const& problem) { - printf("ConvProblemShape with %d spatial dimensions implementing cutlass::conv::Operator::%d\n", - SpatialDim, int(ConvOp)); - printf("\tTensorA: "); - cute::print(problem.shape_A); printf(":"); - cute::print(problem.stride_A); printf("\n"); - printf("\tTensorB: "); - cute::print(problem.shape_B); printf(":"); - cute::print(problem.stride_B); printf("\n"); - printf("\tTensorC: "); - cute::print(problem.shape_C); printf(":"); - cute::print(problem.stride_C); printf("\n"); - printf("\tLower padding: "); print(problem.lower_padding); printf("\n"); - printf("\tUpper padding: "); print(problem.upper_padding); printf("\n"); - printf("\tTraversal strides: "); print(problem.traversal_stride); printf("\n"); - printf("\tDilation: "); print(problem.dilation); printf("\n"); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convolution.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convolution.h deleted file mode 100644 index a3cc98b4740115aefd557468d01ad28fa9a1028a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convolution.h +++ /dev/null @@ -1,194 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief - -This file contains definitions and utility functions for describing convolution problem sizes in terms of -activation (NHWC), filter (KRSC), output (NPQK), padding (pad_h, pad_w), stride (stride_h, stride_w), and -dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map CUTLASS's implicit gemm -tensor extents, sizes, and data types to that of the convolution's extents, sizes, and data types. - - * Mapping convolutions to Gemm computation * - -Cutlass implements convolutions with the Implicit Gemm algorithm. This algorithm performs a gemm -(general matrix-matrix multiply) on the convolution tensors Activation, Filter, and Output. -The underlying gemm operation follows the standard gemm definition: - - C = A * B + C - - A and B are input matrices - C is source and output matrix - - -For the three convolutional operators (Fprop, Dgrad, Wgrad), ImplicitGemm matrices A, B, and C are mapped -to convolution tensors Activation, Filter and Output as described in the table below. - - ___________________________________________________________________________ - ConvolutionalOperator | A | B | C - ___________________________________________________________________________ - | | | | | - | Fprop | Activation | Filter | Output | - | Dgrad | Output | Filter | Activation | - | Wgrad | Output | Activation | Filter | - ___________________________________________________________________________ - -In convolution codebase, DO NOT mix using (A, B, C) with (Activation, Filter, Output). - -For example, it's confusing and error prone to document a convolution class or function -as operating on "A, B, Output." Instead, use the mapping functions below, -and adhere to using either A, B, C or Activation, Filter, Output. - -Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap -Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm_enumerated_types.h" -#include "cutlass/matrix_coord.h" - -namespace cutlass { -namespace conv { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Convolutional operator -enum class Operator { - kFprop, - kDgrad, - kWgrad, - kDeconv -}; - -/// Distinguishes convolution from cross correlation -enum class Mode { - kCrossCorrelation, - kConvolution -}; - -/// Selects among several implementation variants trading off performance with simplicity -enum class IteratorAlgorithm { - kAnalytic, ///< functionally correct in all cases but lower performance - kOptimized, ///< optimized for R <= 32, S <= 32 and unity-stride dgrad - kFixedChannels, ///< Analytic algorithm optimized for fixed channel count (C == AccessSize) - kFewChannels, ///< Analytic algorithm optimized for few channels (C divisible by AccessSize) - kFixedStrideDilation ///< Optimized for fixed stride and dilation -}; - -/// Distinguishes among partial specializations that accelerate certain problems where convolution -/// stride is unit. -enum class StrideSupport { - kStrided, ///< arbitrary convolution stride - kUnity, ///< unit convolution stride - kFixed ///< fixed convolution stride -}; - -/// Identifies split-K mode -enum class SplitKMode { - kNone, - kSerial, - kParallel -}; - -/// Identifies group mode -enum class GroupMode { - kNone, - kSingleGroup, ///< One CTA calculates one group or less - kMultipleGroup, ///< One CTA calculates multiple groups - kDepthwise ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups) -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Shape of a tensor -template < - int N = 1, - int H = 1, - int W = 1, - int C = 1 -> -struct TensorNHWCShape { - static int const kN = N; - static int const kH = H; - static int const kW = W; - static int const kC = C; - - static int const kHW = H * W; - static int const kNHW = N * kHW; - static int const kNHWC = N * H * W * C; - - static int const kCount = kNHWC; - - // - // Static member functions - // - - /// Returns a Coord object - CUTLASS_HOST_DEVICE - static Coord<4> toCoord() { - return make_Coord(kN, kH, kW, kC); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Shape of a conv2d stride, which controls how the filter convolves around the input volume -template < - /// Stride in horizontal direction - int u = 1, - /// Stride in vertical direction - int v = 1 -> -struct Stride2D { - static int const kU = u; - static int const kV = v; - - // - // Static member functions - // - - /// Returns a Coord object - CUTLASS_HOST_DEVICE - static Coord<2> toCoord() { - return make_Coord(kU, kV); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace conv -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/detail.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/detail.hpp deleted file mode 100644 index 0802921d60ce1809a7da67805de0f045c3511b19..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/detail.hpp +++ /dev/null @@ -1,137 +0,0 @@ - -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/conv/convnd_problem_shape.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv::detail { - -///////////////////////////////////////////////////////////////////////////////////////////////// - - // Helper function to get the problem shape -template -auto get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::true_type) { - return T::get_problem_shape_MNKL(problem_shape); -} - -template -ProblemShape get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::false_type) { - return problem_shape; -} - -// Get problem shape MNKL according to following table: -// | | Fprop | Dgrad | Wgrad | -// | ---- | --------- | -------- | -------- | -// | Shape_M | (Q,P,Z,N) | (W/V,H/U,D/O,N) | (K) | -// | Shape_N | (K) | (C) | (C,S,R,T) | -// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) | -// | Shape_L | _1 | (V,U,O) | _1 | - -template -CUTLASS_HOST_DEVICE -constexpr auto -get_transformed_problem_shape_MNKL(ProblemShape const& problem_shape) { - return problem_shape; -} - - -template -CUTLASS_HOST_DEVICE -constexpr auto -get_transformed_problem_shape_MNKL(ConvProblemShape const& problem_shape) { - using cute::insert; - using cute::make_shape; - using cute::reverse; - using cute::take; - - constexpr int RankT = SpatialDim + 2; - - if constexpr (ConvOp == conv::Operator::kWgrad) { - auto M_xformed = problem_shape.shape_C[0]; - auto N_xformed = reverse(take<1, RankT>(problem_shape.shape_C)); - auto K_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_A)); - auto L_xformed = cute::Int<1>{}; - - return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); - } - else if constexpr (ConvOp == conv::Operator::kFprop){ - auto M_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_C)); - auto N_xformed = problem_shape.shape_C[RankT - 1]; - auto K_xformed = reverse(take<1, RankT>(problem_shape.shape_B)); - auto L_xformed = cute::Int<1>{}; - - return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); - } - else if constexpr (ConvOp == conv::Operator::kDgrad) { - auto L_xformed = reverse(problem_shape.traversal_stride); // (V,U,O) - auto M_xformed = ceil_div(reverse(take<0,RankT - 1>(problem_shape.shape_C)), L_xformed); - auto N_xformed = problem_shape.shape_C[RankT - 1]; - // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T] - auto K_xformed = insert<0>( - (reverse(take<1,RankT - 1>(problem_shape.shape_B))), - problem_shape.shape_B[0]); - - return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); - } -} - -// Assuming im2col linearization -// Get problem shape MNKL according to following table: -// | | Fprop | Dgrad | Wgrad | -// | ---- | --------- | -------- | -------- | -// | Shape_M | (Q*P*Z*N) | ([W/V]*[H/U]*[D/O]*N) | (K) | -// | Shape_N | (K) | (C) | (C,S,R,T) | -// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q*P*Z*N) | -// | Shape_L | _1 | (V*U*O) | _1 | -template -CUTLASS_HOST_DEVICE -constexpr auto -get_linearized_problem_shape_MNKL(ConvProblemShape const& problem_shape) { - - auto [M, N, K, L] = get_transformed_problem_shape_MNKL(problem_shape); - - if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { - return cute::make_shape(cute::product(M), N, K, cute::product(L)); - } - else if constexpr (ConvOp == conv::Operator::kWgrad) { - return cute::make_shape(M, N, cute::product(K), L); - } - -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv::detail - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp deleted file mode 100644 index d60469f429f94f4b8152a02d9db232eea5698e56..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp +++ /dev/null @@ -1,448 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -// common -#include "cutlass/arch/mma.h" -#include "cutlass/cutlass.h" -#include "cutlass/arch/mma.h" -#include "cutlass/trace.h" -#include "cutlass/cluster_launch.hpp" -#include "cutlass/device_kernel.h" - -#include "cutlass/conv/kernel/conv_universal.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/detail/layout.hpp" -#include "cutlass/cuda_host_adapter.hpp" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv::device { - -//////////////////////////////////////////////////////////////////////////////// - -/*! - ConvUniversalAdapter is a stateful, reusable handle built around a kernel - of type cutlass::conv::kernel::ConvUniversal. - - It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs - to create it from the host facing arguments. For power users, static methods - are exposed that bypass the stateful methods or args->params lowering. -*/ -template -class ConvUniversalAdapter -{ -public: - using ConvKernel = GetUnderlyingKernel_t; - using TileShape = typename ConvKernel::TileShape; - using ElementA = typename ConvKernel::ElementA; - using ElementB = typename ConvKernel::ElementB; - using ElementC = typename ConvKernel::ElementC; - using ElementD = typename ConvKernel::ElementD; - using ElementAccumulator = typename ConvKernel::TiledMma::ValTypeC; - using DispatchPolicy = typename ConvKernel::DispatchPolicy; - using CollectiveMainloop = typename ConvKernel::CollectiveMainloop; - using CollectiveEpilogue = typename ConvKernel::CollectiveEpilogue; - - static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; - - // Tease out meta-information about the conv algorithm - static constexpr conv::Operator kConvolutionalOperator = DispatchPolicy::ConvOp; - static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; - - // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop! - using OperatorClass = cute::conditional_t< - (cute::size(typename ConvKernel::TiledMma::AtomThrID{}) > 1), - cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; - - using ArchTag = typename ConvKernel::ArchTag; - - // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape - using ThreadblockShape = cutlass::gemm::GemmShape< - cute::size<0>(TileShape{}), - cute::size<1>(TileShape{}), - cute::size<2>(TileShape{})>; - - using ClusterShape = cutlass::gemm::GemmShape< - cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), - cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), - cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})>; - - // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape - using InstructionShape = cutlass::gemm::GemmShape< - cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), - cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), - cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; - - // Legacy: provide a correct warp count, but no reliable warp shape - static int const kThreadCount = ConvKernel::MaxThreadsPerBlock; - - // Warp shape is not a primary API type in 3.x - // But we can best approximate it by inspecting the TiledMma - // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K - // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads - static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename ConvKernel::TiledMma{})) / 32); - static constexpr int WarpsInMmaM = 4; - static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); - using WarpCount = cutlass::gemm::GemmShape; - using WarpShape = cutlass::gemm::GemmShape< - CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, - CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, - CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; - - static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; - - // Inspect TiledCopy for A and B to compute the alignment size - static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename CollectiveMainloop::GmemTiledCopyA, ElementA>(); - static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename CollectiveMainloop::GmemTiledCopyB, ElementB>(); - static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); - static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); - - using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; - - /// Argument structure: User API - using Arguments = typename ConvKernel::Arguments; - /// Argument structure: Kernel API - using Params = typename ConvKernel::Params; - -private: - - /// Kernel API parameters object - Params params_; - -public: - - /// Access the Params structure - Params const& params() const { - return params_; - } - - /// Determines whether the conv can execute the given problem. - static Status - can_implement(Arguments const& args) { - if (ConvKernel::can_implement(args)) { - return Status::kSuccess; - } - else { - return Status::kInvalid; - } - } - - /// Gets the workspace size - static size_t - get_workspace_size(Arguments const& args) { - size_t workspace_bytes = 0; - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - workspace_bytes += ConvKernel::get_workspace_size(args); - return workspace_bytes; - } - - /// Computes the grid shape - static dim3 - get_grid_shape(Arguments const& args, void* workspace = nullptr) { - auto tmp_params = ConvKernel::to_underlying_arguments(args, workspace); - return ConvKernel::get_grid_shape(tmp_params); - } - - /// Computes the grid shape - static dim3 - get_grid_shape(Params const& params) { - return ConvKernel::get_grid_shape(params); - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int /* smem_capacity */ = -1) { - CUTLASS_TRACE_HOST("ConvUniversal::maximum_active_blocks()"); - int max_active_blocks = -1; - int smem_size = ConvKernel::SharedStorageSize; - - // first, account for dynamic smem capacity if needed - cudaError_t result; - if (smem_size >= (48 << 10)) { - CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); - result = cudaFuncSetAttribute( - device_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST( - " cudaFuncSetAttribute() returned error: " - << cudaGetErrorString(result)); - return -1; - } - } - - // query occupancy after setting smem size - result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, - device_kernel, - ConvKernel::MaxThreadsPerBlock, - smem_size); - - if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " - << cudaGetErrorString(result)); - return -1; - } - - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - - /// Initializes conv state from arguments. - Status - initialize( - Arguments const& args, - void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { - - CUTLASS_TRACE_HOST("ConvUniversal::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); - - // Initialize the workspace - Status status = ConvKernel::initialize_workspace(args, workspace, stream, cuda_adapter); - if (status != Status::kSuccess) { - return status; - } - - // Initialize the Params structure - params_ = ConvKernel::to_underlying_arguments(args, workspace); - - // Don't set the function attributes - require the CudaHostAdapter to set it. - if constexpr (kEnableCudaHostAdapter) { - CUTLASS_ASSERT(cuda_adapter); - return Status::kSuccess; - } - else { - // account for dynamic smem capacity if needed - int smem_size = ConvKernel::SharedStorageSize; - if (smem_size >= (48 << 10)) { - CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); - cudaError_t result = cudaFuncSetAttribute( - device_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - } - return Status::kSuccess; - } - - /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. - Status - update(Arguments const& args, void* workspace = nullptr) { - CUTLASS_TRACE_HOST("ConvUniversal()::update() - workspace: " << workspace); - - size_t workspace_bytes = get_workspace_size(args); - if (workspace_bytes > 0 && nullptr == workspace) { - return Status::kErrorWorkspaceNull; - } - - params_ = ConvKernel::to_underlying_arguments(args, workspace); - return Status::kSuccess; - } - - /// Primary run() entry point API that is static allowing users to create and manage their own params. - /// Supplied params struct must be construct by calling ConvKernel::to_underling_arguments() - static Status - run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { - CUTLASS_TRACE_HOST("ConvUniversal::run()"); - dim3 const block = ConvKernel::get_block_shape(); - dim3 const grid = get_grid_shape(params); - - // configure smem size and carveout - int smem_size = ConvKernel::SharedStorageSize; - - Status launch_result; - // Use extended launch API only for mainloops that use it - if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) { - [[maybe_unused]] constexpr bool is_static_1x1x1 = - cute::is_static_v and - cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1; - dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), - cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), - cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})); - // Dynamic cluster support - [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0}; - if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 || - ConvKernel::ArchTag::kMinComputeCapability == 101) { - if constexpr (!cute::is_static_v) { - fallback_cluster = params.hw_info.cluster_shape_fallback; - cluster = params.hw_info.cluster_shape; - } - } - - void* kernel_params[] = {¶ms}; - if constexpr (kEnableCudaHostAdapter) { - // - // Use the cuda host adapter - // - CUTLASS_ASSERT(cuda_adapter); - if (cuda_adapter) { - - launch_result = cuda_adapter->launch(grid, - cluster, - fallback_cluster, - block, - smem_size, - stream, - kernel_params, - kernel_index); - } - else { - return Status::kErrorInternal; - } - } - else { - CUTLASS_ASSERT(cuda_adapter == nullptr); - void const* kernel = (void const*) device_kernel; - if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90 - || ConvKernel::ArchTag::kMinComputeCapability == 100 - ) { - if constexpr (is_static_1x1x1) { - device_kernel<<>>(params); - launch_result = Status::kSuccess; - } - else { - launch_result = ClusterLauncher::launch( - grid, cluster, block, smem_size, stream, kernel, kernel_params); - } - } - else { - if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 || - ConvKernel::ArchTag::kMinComputeCapability == 101) { - launch_result = ClusterLauncher::launch_with_fallback_cluster( - grid, - cluster, - fallback_cluster, - block, - smem_size, - stream, - kernel, - kernel_params); - } - } - } - } - else { - launch_result = Status::kSuccess; - - if constexpr (kEnableCudaHostAdapter) { - CUTLASS_ASSERT(cuda_adapter); - if (cuda_adapter) { - void* kernel_params[] = {¶ms}; - - launch_result = cuda_adapter->launch( - grid, block, smem_size, stream, kernel_params, 0 - ); - - } - else { - return Status::kErrorInternal; - } - } - else { - CUTLASS_ASSERT(cuda_adapter == nullptr); - device_kernel<<>>(params); - } - } - - cudaError_t result = cudaGetLastError(); - if (cudaSuccess == result && Status::kSuccess == launch_result) { - return Status::kSuccess; - } - else { - CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); - return Status::kErrorInternal; - } - } - - // - // Non-static launch overloads that first create and set the internal params struct of this kernel handle. - // - - /// Launches the kernel after first constructing Params internal state from supplied arguments. - Status - run( - Arguments const& args, - void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr, - int32_t kernel_index = 0 - ) { - Status status = initialize(args, workspace, stream, cuda_adapter); - if (Status::kSuccess == status) { - status = run(params_, stream, cuda_adapter, kernel_index); - } - return status; - } - - /// Launches the kernel after first constructing Params internal state from supplied arguments. - Status - operator()( - Arguments const& args, - void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { - return run(args, workspace, stream, cuda_adapter); - } - - /// Overload that allows a user to re-launch the same kernel without updating internal params struct. - Status - run(cudaStream_t stream = nullptr) { - return run(params_, stream); - } - - /// Overload that allows a user to re-launch the same kernel without updating internal params struct. - Status - operator()(cudaStream_t stream = nullptr) { - return run(params_, stream); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv::device - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/direct_convolution.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/direct_convolution.h deleted file mode 100644 index 387574b989681ba6f9e5e6fa333dda109b7f7aa6..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/direct_convolution.h +++ /dev/null @@ -1,270 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Template for device-level Depthwise Convolution -*/ - -#pragma once - -#include - -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/conv/convolution.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class DirectConvolution { -public: - - using UnderlyingKernel = DirectConvolutionKernel_; - - using ElementA = typename UnderlyingKernel::ElementA; - using LayoutA = typename UnderlyingKernel::LayoutA; - using ElementB = typename UnderlyingKernel::ElementB; - using LayoutB = typename UnderlyingKernel::LayoutB; - using ElementC = typename UnderlyingKernel::ElementC; - using LayoutC = typename UnderlyingKernel::LayoutC; - using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; - using ElementCompute = typename UnderlyingKernel::ElementCompute; - using OperatorClass = typename UnderlyingKernel::OperatorClass; - using ArchTag = typename UnderlyingKernel::ArchTag; - using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; - using WarpShape = typename UnderlyingKernel::WarpShape; - using InstructionShape = typename UnderlyingKernel::InstructionShape; - using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; - using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; - static int const kStages = UnderlyingKernel::kStages; - static int const kConvDim = UnderlyingKernel::kConvDim; - using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; - using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; - using MathOperator = typename UnderlyingKernel::MathOperator; - - static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; - static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; - static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; - static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; - - static int const kWarpCount = - (ThreadblockShape::kM / WarpShape::kM) * - (ThreadblockShape::kN / WarpShape::kN) * - (ThreadblockShape::kK / WarpShape::kK); - - /// Argument structure - using Arguments = typename UnderlyingKernel::Arguments; - - using ReorderKernel = typename UnderlyingKernel::ReorderKernel; - - private: - - /// Kernel parameters object - typename UnderlyingKernel::Params params_; - -public: - - /// Constructs Implicit GEMM - DirectConvolution() { } - - /// Determines whether the Implicit GEMM can execute the given problem. - static Status can_implement(Arguments const &args) { - - // dispatch to iterators - Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); - if (Status::kSuccess != status) { - return status; - } - - status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); - if (Status::kSuccess != status) { - return status; - } - - if (kGroupMode != conv::GroupMode::kDepthwise) { - return Status::kErrorInvalidProblem; - } - - // C and K should be multiple of groups - if (args.problem_size.K != args.problem_size.groups && - args.problem_size.C != args.problem_size.groups) { - return Status::kErrorInvalidProblem; - } - - - static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; - if (kConvolutionalOperator == conv::Operator::kFprop) { - if (args.problem_size.K % kAlignmentC) - return Status::kErrorMisalignedOperand; - } else if (kConvolutionalOperator == conv::Operator::kDgrad) { - if (args.problem_size.C % kAlignmentC) - return Status::kErrorMisalignedOperand; - } else if (kConvolutionalOperator == conv::Operator::kWgrad) { - if (args.problem_size.C % kAlignmentC) - return Status::kErrorMisalignedOperand; - } - - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - - dim3 grid = threadblock_swizzle.get_grid_shape( - threadblock_swizzle.get_tiled_shape( - kConvolutionalOperator, - args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.problem_size.split_k_slices)); - - if (!(grid.y <= std::numeric_limits::max() && - grid.z <= std::numeric_limits::max())) { - - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const &args) { - return 0; - } - - /// Initializes GEMM state from arguments. - Status initialize( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr) { - - // initialize the params structure from the arguments - params_ = typename UnderlyingKernel::Params( - args, - static_cast(workspace) - ); - - int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) { - cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Initializes GEMM state from arguments. - Status update(Arguments const &args, void *workspace = nullptr) { - - // update the params structure from the arguments - params_.ptr_A = args.ref_A.data(); - params_.ptr_B = args.ref_B.data(); - params_.ptr_C = args.ref_C.data(); - params_.ptr_D = args.ref_D.data(); - params_.output_op = args.output_op; - params_.ptr_reordered_B = args.ref_reordered_B.data(); - params_.semaphore = static_cast(workspace); - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) { - - // Launch reorder kernel - if (params_.ptr_reordered_B != nullptr) { - dim3 grid = ReorderKernel::get_grid_shape(params_); - dim3 block = ReorderKernel::get_block_shape(); - - cutlass::arch::synclog_setup(); - cutlass::Kernel<<>>(params_); - } - - // Launch main kernel - ThreadblockSwizzle threadblock_swizzle; - - dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); - dim3 block(32 * kWarpCount, 1, 1); - - // Dynamic SMEM size based on input params. - int smem_size = int(params_.get_smem_size()); - - // Make sure we can use that much shared memory. - cudaError_t status = - cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - if (status != cudaSuccess) - return Status::kErrorInternal; - - cutlass::arch::synclog_setup(); - cutlass::Kernel<<>>(params_); - - cudaError_t result = cudaGetLastError(); - - return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; - } - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) { - return run(stream); - } - - /// Runs the kernel using initialized state. - Status operator()( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr) { - - Status status = initialize(args, workspace, stream); - - if (status == Status::kSuccess) { - status = run(stream); - } - - return status; - } - - int get_smem_size() { return int(params_.get_smem_size()); } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} -} -} - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h deleted file mode 100644 index a9aae87bc1c57a20e27298b4f227726dd199a769..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h +++ /dev/null @@ -1,388 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Template for device-level Implicit GEMM Convolution -*/ - -#pragma once - -#include - -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/cuda_host_adapter.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class ImplicitGemmConvolution { -public: - - using UnderlyingKernel = GetUnderlyingKernel_t; - - using ElementA = typename UnderlyingKernel::ElementA; - using LayoutA = typename UnderlyingKernel::LayoutA; - using ElementB = typename UnderlyingKernel::ElementB; - using LayoutB = typename UnderlyingKernel::LayoutB; - using ElementC = typename UnderlyingKernel::ElementC; - using LayoutC = typename UnderlyingKernel::LayoutC; - using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; - using ElementCompute = typename UnderlyingKernel::ElementCompute; - using OperatorClass = typename UnderlyingKernel::OperatorClass; - using ArchTag = typename UnderlyingKernel::ArchTag; - using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; - using WarpShape = typename UnderlyingKernel::WarpShape; - using InstructionShape = typename UnderlyingKernel::InstructionShape; - using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; - using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; - static int const kStages = UnderlyingKernel::kStages; - static int const kConvDim = UnderlyingKernel::kConvDim; - using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; - using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; - using MathOperator = typename UnderlyingKernel::MathOperator; - - static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; - static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; - static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; - static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; - - static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; - - static int const kWarpCount = - (ThreadblockShape::kM / WarpShape::kM) * - (ThreadblockShape::kN / WarpShape::kN) * - (ThreadblockShape::kK / WarpShape::kK); - - /// Argument structure - using Arguments = typename UnderlyingKernel::Arguments; - -private: - - /// Kernel parameters object - typename UnderlyingKernel::Params params_; - -public: - - /// Constructs Implicit GEMM - ImplicitGemmConvolution() { } - - /// Determines whether the Implicit GEMM can execute the given problem. - static Status can_implement(Arguments const &args) { - // dispatch to iterators - Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); - if (Status::kSuccess != status) { - return status; - } - - status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); - if (Status::kSuccess != status) { - return status; - } - - // Check that tensor sizes don't exceed maximum supported size - if (kConvolutionalOperator == conv::Operator::kFprop) { - if (args.problem_size.activation_size() * sizeof(ElementA) >= - (1ull << 31) || - args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) || - args.problem_size.output_size() * sizeof(ElementC) >= (1ull << 31)) { - return Status::kErrorInvalidProblem; - } - } - else if (kConvolutionalOperator == conv::Operator::kDgrad || - kConvolutionalOperator == conv::Operator::kDeconv) { - if (args.problem_size.activation_size() * sizeof(ElementC) >= - (1ull << 31) || - args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) || - args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) { - return Status::kErrorInvalidProblem; - } - } - else if (kConvolutionalOperator == conv::Operator::kWgrad) { - if (args.problem_size.activation_size() * sizeof(ElementB) >= - (1ull << 31) || - args.problem_size.filter_size() * sizeof(ElementC) >= (1ull << 31) || - args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) { - return Status::kErrorInvalidProblem; - } - } - - // check group conv constraint - if (args.problem_size.groups != 1) { - if (kGroupMode == conv::GroupMode::kNone) { - return Status::kErrorInvalidProblem; - } - - // C and K should be multiple of groups - if (args.problem_size.K % args.problem_size.groups || - args.problem_size.C % args.problem_size.groups) { - return Status::kErrorInvalidProblem; - } - - // split-k is not supported - if (args.problem_size.split_k_slices != 1) { - return Status::kErrorInvalidProblem; - } - - int k_per_group = args.problem_size.K / args.problem_size.groups; - // k_per_group should be multiple of ThreadblockShape N, one CTA calculate one group - if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) { - return Status::kErrorInvalidProblem; - } - // ThreadblockShape::kN should be divisible by k_per_group, one CTA calculate multiple groups - if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) { - return Status::kErrorInvalidProblem; - } - - // current optimized iterator algo only supports SingleGroup mode - if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized && - kGroupMode != conv::GroupMode::kSingleGroup) { - return Status::kErrorInvalidProblem; - } - } - - static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; - if (kConvolutionalOperator == conv::Operator::kFprop) { - if (args.problem_size.K % kAlignmentC) - return Status::kErrorMisalignedOperand; - } else if (kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) { - if (args.problem_size.C % kAlignmentC) - return Status::kErrorMisalignedOperand; - } else if (kConvolutionalOperator == conv::Operator::kWgrad) { - if (args.problem_size.C % kAlignmentC) - return Status::kErrorMisalignedOperand; - } - - // check for unsupported problem sizes for strided dgrad / deconv implementation - if ((kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) && - kStrideSupport == conv::StrideSupport::kStrided) { - // split-k (serial or parallel) is not supported for strided dgrad / deconv - if(args.problem_size.split_k_slices > 1 && (args.problem_size.stride().at(args.problem_size.stride().max_dim_index()) > 1)) { - return Status::kErrorNotSupported; - } - - // dilation > {1x1} is not supported for strided dgrad / deconv - if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) { - return Status::kErrorNotSupported; - } - } - - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - - dim3 grid = threadblock_swizzle.get_grid_shape( - threadblock_swizzle.get_tiled_shape( - kConvolutionalOperator, - args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.problem_size.split_k_slices)); - - if (!(grid.y <= std::numeric_limits::max() && - grid.z <= std::numeric_limits::max())) { - - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const &args) { - - size_t workspace_bytes = 0; - - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - kConvolutionalOperator, - args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.problem_size.split_k_slices); - - if(args.split_k_mode == SplitKMode::kParallel) { - - // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. - // The user needs to call a reduction operator to optain the final output tensor - workspace_bytes = - sizeof(ElementAccumulator) * - size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) * - size_t(grid_tiled_shape.k()); - } - - else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) { - - // Split-K serial: The user workspace is used to store semaphore and serialize writing the - // final reduced output to user's output tensor - workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); - } - - return workspace_bytes; - } - - /// Initializes GEMM state from arguments. - Status initialize( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { - - if (args.problem_size.split_k_slices > 1) { - - if (!workspace) { - return Status::kErrorWorkspaceNull; - } - - cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); - - if (status != cudaSuccess) { - return Status::kErrorInternal; - } - } - - // initialize the params structure from the arguments - params_ = typename UnderlyingKernel::Params( - args, - static_cast(workspace) - ); - - if constexpr (kEnableCudaHostAdapter) { - CUTLASS_ASSERT(cuda_adapter); - return Status::kSuccess; - } - else { - int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) { - cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - } - } - - return Status::kSuccess; - } - - /// Initializes GEMM state from arguments. - Status update(Arguments const &args, void *workspace = nullptr) { - - // update the params structure from the arguments - params_.ptr_A = args.ref_A.data(); - params_.ptr_B = args.ref_B.data(); - params_.ptr_C = args.ref_C.data(); - params_.ptr_D = args.ref_D.data(); - params_.output_op = args.output_op; - params_.semaphore = static_cast(workspace); - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { - - - ThreadblockSwizzle threadblock_swizzle; - - dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); - dim3 block(32 * kWarpCount, 1, 1); - - int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); - cutlass::Status launch_result = cutlass::Status::kSuccess ; - - if constexpr (kEnableCudaHostAdapter) { - // - // Use the cuda host adapter - // - CUTLASS_ASSERT(cuda_adapter); - if (cuda_adapter) { - - void* kernel_params[] = {¶ms_}; - launch_result = cuda_adapter->launch( - grid, dim3(1,1,1), block, smem_size, stream, kernel_params, kernel_index - ); - } - else { - launch_result = Status::kErrorInternal; - } - } - else { - cutlass::arch::synclog_setup(); - cutlass::Kernel<<>>(params_); - } - - cudaError_t result = cudaGetLastError(); - if (cudaSuccess == result && Status::kSuccess == launch_result) { - return Status::kSuccess; - } - else { - CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); - return Status::kErrorInternal; - } - } - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { - return run(stream, cuda_adapter, kernel_index); - } - - /// Runs the kernel using initialized state. - Status operator()( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { - - Status status = initialize(args, workspace, stream, cuda_adapter); - - if (status == Status::kSuccess) { - status = run(stream, cuda_adapter, kernel_index); - } - - return status; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} -} -} - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h deleted file mode 100644 index efd3dcbad093cf8d11036a63a9b6638d1801aeee..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h +++ /dev/null @@ -1,269 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Template for device-level fused activation's scale+bias+relu and Implicit GEMM Convolution -*/ - -#pragma once - -#include - -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/conv/convolution.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class ImplicitGemmConvolutionFusion { -public: - - using ImplicitGemmFusionKernel = ImplicitGemmFusionKernel_; - - using ElementA = typename ImplicitGemmFusionKernel::ElementA; - using LayoutA = typename ImplicitGemmFusionKernel::LayoutA; - using ElementB = typename ImplicitGemmFusionKernel::ElementB; - using LayoutB = typename ImplicitGemmFusionKernel::LayoutB; - -// using ElementScaleBias = typename ImplicitGemmFusionKernel::ElementScaleBias; -// using LayoutScaleBias = typename ImplicitGemmFusionKernel::LayoutScaleBias; - - using ElementC = typename ImplicitGemmFusionKernel::ElementC; - using LayoutC = typename ImplicitGemmFusionKernel::LayoutC; - using ElementAccumulator = typename ImplicitGemmFusionKernel::ElementAccumulator; - using ElementCompute = typename ImplicitGemmFusionKernel::ElementCompute; - using OperatorClass = typename ImplicitGemmFusionKernel::OperatorClass; - using ArchTag = typename ImplicitGemmFusionKernel::ArchTag; - using ThreadblockShape = typename ImplicitGemmFusionKernel::ThreadblockShape; - using WarpShape = typename ImplicitGemmFusionKernel::WarpShape; - using InstructionShape = typename ImplicitGemmFusionKernel::InstructionShape; - using ThreadblockSwizzle = typename ImplicitGemmFusionKernel::ThreadblockSwizzle; - using EpilogueOutputOp = typename ImplicitGemmFusionKernel::EpilogueOutputOp; - static int const kStages = ImplicitGemmFusionKernel::kStages; - static int const kConvDim = ImplicitGemmFusionKernel::kConvDim; - using WarpMmaOperator = typename ImplicitGemmFusionKernel::WarpMmaOperator; - using ArchMmaOperator = typename ImplicitGemmFusionKernel::ArchMmaOperator; - using MathOperator = typename ImplicitGemmFusionKernel::MathOperator; - - static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmFusionKernel::kConvolutionalOperator; - static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmFusionKernel::kIteratorAlgorithm; - - static int const kWarpCount = - (ThreadblockShape::kM / WarpShape::kM) * - (ThreadblockShape::kN / WarpShape::kN) * - (ThreadblockShape::kK / WarpShape::kK); - - /// Argument structure - using Arguments = typename ImplicitGemmFusionKernel::Arguments; - -private: - - /// Kernel parameters object - typename ImplicitGemmFusionKernel::Params params_; - -public: - - /// Constructs Implicit GEMM - ImplicitGemmConvolutionFusion() { } - - /// Determines whether the Implicit GEMM can execute the given problem. - static Status can_implement(Arguments const &args) { - - // dispatch to iterators - Status status = ImplicitGemmFusionKernel::Mma::IteratorA::can_implement(args.problem_size); - if (Status::kSuccess != status) { - return status; - } - - status = ImplicitGemmFusionKernel::Mma::IteratorB::can_implement(args.problem_size); - if (Status::kSuccess != status) { - return status; - } - - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - - dim3 grid = threadblock_swizzle.get_grid_shape( - threadblock_swizzle.get_tiled_shape( - cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.problem_size.split_k_slices)); - - if (!(grid.y <= std::numeric_limits::max() && - grid.z <= std::numeric_limits::max())) { - - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const &args) { - - size_t workspace_bytes = 0; - - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.problem_size.split_k_slices); - - if(args.split_k_mode == SplitKMode::kParallel) { - - // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. - // The user needs to call a reduction operator to optain the final output tensor - workspace_bytes = - sizeof(ElementAccumulator) * - size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) * - size_t(grid_tiled_shape.k()); - } - - else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) { - - // Split-K serial: The user workspace is used to store semaphore and serialize writing the - // final reduced output to user's output tensor - workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); - } - - return workspace_bytes; - } - - /// Initializes GEMM state from arguments. - Status initialize( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr) { - - if (args.problem_size.split_k_slices > 1) { - - if (!workspace) { - return Status::kErrorWorkspaceNull; - } - - cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); - - if (status != cudaSuccess) { - return Status::kErrorInternal; - } - } - - // initialize the params structure from the arguments - params_ = typename ImplicitGemmFusionKernel::Params( - args, - static_cast(workspace) - ); - - int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) { - cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Initializes Impicit GEMM state from arguments. - Status update(Arguments const &args, void *workspace = nullptr) { - - // update the params structure from the arguments - params_.ptr_A = args.ref_A.data(); - params_.ptr_B = args.ref_B.data(); - params_.ptr_scale = args.ref_A_scale.data(); - params_.ptr_bias = args.ref_A_bias.data(); - params_.ptr_C = args.ref_C.data(); - params_.ptr_D = args.ref_D.data(); - params_.output_op = args.output_op; - params_.semaphore = static_cast(workspace); - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) { - - ThreadblockSwizzle threadblock_swizzle; - - dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); - dim3 block(32 * kWarpCount, 1, 1); - - int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); - - cutlass::arch::synclog_setup(); - cutlass::Kernel<<>>(params_); - - cudaError_t result = cudaGetLastError(); - - return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; - } - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) { - return run(stream); - } - - /// Runs the kernel using initialized state. - Status operator()( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr) { - - Status status = initialize(args, workspace, stream); - - if (status == Status::kSuccess) { - status = run(stream); - } - - return status; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} -} -} - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/dispatch_policy.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/dispatch_policy.hpp deleted file mode 100644 index d569cb1c3e6d6c7da188691a94384d43259d2be0..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/dispatch_policy.hpp +++ /dev/null @@ -1,136 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/conv/convolution.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/arch/arch.h" - -#include "cute/layout.hpp" -#include "cute/numeric/integral_constant.hpp" - -#include "cutlass/gemm/dispatch_policy.hpp" - -////////////////////////////////////////////////////////////////////////////// - -////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv { - -////////////////////////////////////////////////////////////////////////////// - -// -// Policies for categorical dispatch of mainloop against kernel grid schedules -// -struct KernelImplicitTmaWarpSpecializedSm90 : cutlass::gemm::KernelTmaWarpSpecialized { }; -struct KernelImplicitTmaWarpSpecializedSm90Cooperative { }; -struct KernelImplicitTmaWarpSpecializedSm90Pingpong { }; - -// -// Collective Mainloop Policies -// - -// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA -// for fprop -template< - conv::Operator ConvOp_, - int Stages_, - int NumSpatialDimensions_, - class ClusterShape_ = cute::Shape,cute::C<1>,cute::C<1>>, - class KernelSchedule = KernelImplicitTmaWarpSpecializedSm90, - int PipelineAsyncMmaStages_ = 1 -> -struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm { - static constexpr int Stages = Stages_; - static constexpr int NumSpatialDimensions = NumSpatialDimensions_; - static constexpr Operator ConvOp = ConvOp_; - static constexpr int PipelineAsyncMmaStages = PipelineAsyncMmaStages_; - using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm90; - using Schedule = KernelSchedule; - - static_assert(NumSpatialDimensions >= 1); - static_assert(! (cute::is_same_v || - cute::is_same_v), - "Persistent schedules not support for conv yet."); -}; - - - -// SM100 tensor op kernel schedule -struct KernelImplicitTmaWarpSpecializedSm100 { - static constexpr int SchedulerPipelineStageCount = 0; - static constexpr int AccumulatorPipelineStageCount = 0; -}; - -// Pseudo-policies for builder auto override that dispatches to the KernelImplicitTmaWarpSpecializedSm100 -// but for opting into 1 or 2 SM atoms -struct KernelImplicitTmaWarpSpecialized1SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { }; -struct KernelImplicitTmaWarpSpecialized2SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { }; - -struct KernelStridedDgradTmaWs1SmSm100 { }; -struct KernelStridedDgradTmaWs2SmSm100 { }; - -// Policy for implicit gemm kernel -template< - int SchedulerPipelineStageCount_, - int AccumulatorPipelineStageCount_ -> -struct KernelScheduleImplicitTmaWarpSpecializedSm100 : KernelImplicitTmaWarpSpecializedSm100 { - static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; - static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; -}; - -// n-buffer in smem (Blackwell TMA), pipelined with Blackwell UMMA and TMA, fprop -template< - conv::Operator ConvOp_, - int Stages_, - int NumSpatialDimensions_, - int SchedulerPipelineStageCount_, - int AccumulatorPipelineStageCount_, - class ClusterShape_ = cute::Shape,cute::C<1>,cute::C<1>> -> -struct MainloopSm100TmaUmmaWarpSpecializedImplicitGemm { - static constexpr int Stages = Stages_; - static constexpr int NumSpatialDimensions = NumSpatialDimensions_; - static constexpr Operator ConvOp = ConvOp_; - using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; - using Schedule = KernelScheduleImplicitTmaWarpSpecializedSm100; - - static_assert(NumSpatialDimensions >= 1); -}; - -////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv - -////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/conv_universal.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/conv_universal.hpp deleted file mode 100644 index af804df30e76a156af33f7095da64614370e466c..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/conv_universal.hpp +++ /dev/null @@ -1,65 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/conv/convnd_problem_shape.hpp" -#include "cutlass/detail/dependent_false.hpp" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv::kernel { - -//////////////////////////////////////////////////////////////////////////////// - -/* - * Stateless universal device CONV kernel type that treats CONV as - * a composition of a collective mainloop and a collective epilogue. -**/ -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileSchedulerTag_ = void, - class Enable = void -> -class ConvUniversal { - static_assert(cutlass::detail::dependent_false, - "Could not find a valid specialization at the kernel layer to dispatch against."); -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv::kernel - -//////////////////////////////////////////////////////////////////////////////// -#include "cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp" -#include "cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp" -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d.h deleted file mode 100644 index f9647a598799cf233962457f8d2cad7e59e46cf5..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d.h +++ /dev/null @@ -1,322 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level implicit GEMM convolution definitions for threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -#include "cutlass/conv/threadblock/threadblock_swizzle.h" -#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" -#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -#include "cutlass/conv/threadblock/implicit_gemm_pipelined.h" -#include "cutlass/conv/threadblock/implicit_gemm_multistage.h" -#include "cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h" -#include "cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h" -#include "cutlass/conv/kernel/implicit_gemm_convolution.h" -#include "cutlass/conv/kernel/implicit_gemm_convolution_fusion.h" -#include "cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template < - typename ArchTag, - typename Shape, - typename WarpMmaTensorOp, - int PartitionsK, - typename OutputOp -> -struct DefaultConvEpilogue { - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputOp, - OutputOp::kCount - >::Epilogue; -}; - -template < - typename Shape, - typename WarpMmaTensorOp, - int PartitionsK, - typename OutputOp -> -struct DefaultConvEpilogue< - arch::Sm70, - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputOp -> { - - using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOp< - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputOp, - OutputOp::kCount - >::Epilogue; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename ArchTag, - typename Shape, - typename WarpMmaSimt, - typename ElementOutput, - typename ElementTensor, - typename ElementVector, - typename OutputOp, - int ElementsPerAccess, - typename PermuteDLayout = layout::NoPermute, - conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity, - int Rank = 4 -> -struct DefaultConvEpilogueWithBroadcastSimt { - using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimt< - Shape, - WarpMmaSimt, - ElementOutput, - ElementTensor, - ElementVector, - OutputOp, - ElementsPerAccess, - false, - PermuteDLayout, - StrideSupport, - Rank - >::Epilogue; -}; - -template < - typename ArchTag, - typename Shape, - typename WarpMmaSimt, - typename ElementOutput, - typename ElementTensor, - typename ElementVector, - typename OutputOp, - int ElementsPerAccess -> -struct DefaultConvEpilogueWithBroadcastSimtStridedDgrad { - using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimtStridedDgrad< - Shape, - WarpMmaSimt, - ElementOutput, - ElementTensor, - ElementVector, - OutputOp, - ElementsPerAccess - >::Epilogue; -}; - -template < - typename ArchTag, - typename Shape, - typename WarpMmaTensorOp, - int PartitionsK, - typename ElementOutput, - typename ElementTensor, - typename ElementVector, - typename OutputOp, - int ElementsPerAccess -> -struct DefaultConvEpilogueWithBroadcastTensorOp { - using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp< - Shape, - WarpMmaTensorOp, - PartitionsK, - ElementOutput, - ElementTensor, - ElementVector, - OutputOp, - ElementsPerAccess - >::Epilogue; -}; - -template < - typename Shape, - typename WarpMmaTensorOp, - int PartitionsK, - typename ElementOutput, - typename ElementTensor, - typename ElementVector, - typename OutputOp, - int ElementsPerAccess -> -struct DefaultConvEpilogueWithBroadcastTensorOp< - arch::Sm70, - Shape, - WarpMmaTensorOp, - PartitionsK, - ElementOutput, - ElementTensor, - ElementVector, - OutputOp, - ElementsPerAccess - > { - using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp< - Shape, - WarpMmaTensorOp, - PartitionsK, - ElementOutput, - ElementTensor, - ElementVector, - OutputOp, - ElementsPerAccess - >::Epilogue; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ArchTag, - typename Shape, - typename WarpMmaTensorOp, - int PartitionsK, - typename ElementOutput, - typename OutputOp, - typename ReductionOp, - int ElementsPerAccess -> -struct DefaultConvEpilogueWithReductionTensorOp { - using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< - Shape, - WarpMmaTensorOp, - PartitionsK, - ElementOutput, - OutputOp, - ReductionOp, - ElementsPerAccess - >::Epilogue; -}; - -template < - typename Shape, - typename WarpMmaTensorOp, - int PartitionsK, - typename ElementOutput, - typename OutputOp, - typename ReductionOp, - int ElementsPerAccess -> -struct DefaultConvEpilogueWithReductionTensorOp< - arch::Sm70, - Shape, - WarpMmaTensorOp, - PartitionsK, - ElementOutput, - OutputOp, - ReductionOp, - ElementsPerAccess - > { - using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionVoltaTensorOp< - Shape, - WarpMmaTensorOp, - PartitionsK, - ElementOutput, - OutputOp, - ReductionOp, - ElementsPerAccess - >::Epilogue; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Defaults for strided Dgrad -template < - typename ArchTag, - typename Shape, - typename WarpMmaTensorOp, - int PartitionsK, - typename OutputOp -> -struct DefaultConvEpilogueStridedDgrad { - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputOp, - OutputOp::kCount - >::Epilogue; -}; - -template < - typename Shape, - typename WarpMmaTensorOp, - int PartitionsK, - typename OutputOp -> -struct DefaultConvEpilogueStridedDgrad< - arch::Sm70, - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputOp -> { - - using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOpStridedDgrad< - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputOp, - OutputOp::kCount - >::Epilogue; -}; - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h deleted file mode 100644 index 27a96a5602494e2abe3980b3d07d54c49dcb9932..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h +++ /dev/null @@ -1,1927 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped - matrix multiply-add with the appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dDgrad -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> struct DefaultConv2dDgrad; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassTensorOp convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided and -// multistage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kStrided, - AccessTypeA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - StrideSupport::kStrided, - AccessTypeB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; -}; - -/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided -// and 2 stage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kStrided, - AccessTypeA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - StrideSupport::kStrided, - AccessTypeB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity Strided -// and multistage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity, - AccessTypeA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - StrideSupport::kUnity, - AccessTypeB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; -}; - -/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity -// 2 stage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity, - AccessTypeA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - StrideSupport::kUnity, - AccessTypeB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dDgrad specialization for optimized IteratorAlgorithm Dgrad Unity Strided -// and multistage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity, - AccessTypeA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - StrideSupport::kUnity, - AccessTypeB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; -}; - -/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided and -// multistage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kStrided, - AccessTypeA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - StrideSupport::kStrided, - AccessTypeB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; -}; - -/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided -// and 2 stage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kStrided, - AccessTypeA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - StrideSupport::kStrided, - AccessTypeB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; -}; - -/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Unity -// 2 stage pipeline -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity, - AccessTypeA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - StrideSupport::kUnity, - AccessTypeB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kUnity - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - conv::StrideSupport::kUnity - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - conv::StrideSupport::kStrided - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - StrideSupport::kUnity - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - conv::StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - conv::StrideSupport::kStrided - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; - -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kUnity - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - conv::StrideSupport::kUnity - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; - -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - conv::StrideSupport::kStrided - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - StrideSupport::kUnity - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - conv::StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - conv::StrideSupport::kStrided - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad - >; - -}; - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h deleted file mode 100644 index 932d1abdc6e2c80a4a1e8d1eb805cbedbe5ac78a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ /dev/null @@ -1,2007 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped - matrix multiply-add with the appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h" -#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h" - -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kUnity, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> struct DefaultConv2dFprop; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassTensorOp convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -/// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA, - AccessTypeA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - AccessTypeB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -/// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kFixedChannels, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA, - AccessTypeA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - AccessTypeB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and two stage -/// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kFixedChannels, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA, - AccessTypeA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - AccessTypeB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -/// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kFewChannels, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA, - AccessTypeA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - AccessTypeB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -/// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kFewChannels, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA, - AccessTypeA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - AccessTypeB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -/// pipeline with interleaved layout. -template < - typename ElementA, - typename ElementB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB, - int InterleavedK -> -struct DefaultConv2dFprop < - ElementA, - layout::TensorNCxHWx, - ElementB, - layout::TensorCxRSKx, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, - ElementB, layout::RowMajorInterleaved, - ElementAccumulator, LayoutC, arch::OpClassTensorOp, - Stages, MathOperatorTag, true>; - - // Define iterators over tiles from the A operand - // Note GEMM shared memory threadmap is used here because conv global memory - // layout needs to be mapped to fprop which is similar to the crosswise - // layout which is used by the interleaved GEMM shared memory threadmap. - // The Interleaved GEMM global memory layout is similar to the congruous - // layout. - using ThreadMapA = typename MmaCore::SmemThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, layout::TensorNCxHWx, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - // Note GEMM shared memory threadmap is used here because conv global memory - // layout needs to be mapped to fprop which is similar to the crosswise - // layout which is used by the interleaved GEMM shared memory threadmap. - // The Interleaved GEMM global memory layout is similar to the congruous - // layout. - using ThreadMapB = typename MmaCore::SmemThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, layout::TensorCxRSKx, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Global, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - InterleavedK - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm -/// and 2 stage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA, - AccessTypeA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - AccessTypeB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage -/// pipeline with interleaved layout. -template < - typename ElementA, - typename ElementB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB, - int InterleavedK -> -struct DefaultConv2dFprop < - ElementA, - layout::TensorNCxHWx, - ElementB, - layout::TensorCxRSKx, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, - ElementB, layout::RowMajorInterleaved, - ElementAccumulator, LayoutC, arch::OpClassTensorOp, - 2, MathOperatorTag, true>; - - // Define iterators over tiles from the A operand - // Note GEMM shared memory threadmap is used here because conv global memory - // layout needs to be mapped to fprop which is similar to the crosswise - // layout which is used by the interleaved GEMM shared memory threadmap. - // The Interleaved GEMM global memory layout is similar to the congruous - // layout. - using ThreadMapA = typename MmaCore::SmemThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, layout::TensorNCxHWx, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - // Note GEMM shared memory threadmap is used here because conv global memory - // layout needs to be mapped to fprop which is similar to the crosswise - // layout which is used by the interleaved GEMM shared memory threadmap. - // The Interleaved GEMM global memory layout is similar to the congruous - // layout. - using ThreadMapB = typename MmaCore::SmemThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, layout::TensorCxRSKx, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - InterleavedK - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and -/// multistage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag - >; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA, - AccessTypeA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB, - AccessTypeB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport, - 4 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and -// multistage pipeline with interleaved layout. -template < - typename ElementA, - typename ElementB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB, - int InterleavedK -> -struct DefaultConv2dFprop < - ElementA, - layout::TensorNCxHWx, - ElementB, - layout::TensorCxRSKx, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, - ElementB, layout::RowMajorInterleaved, ElementAccumulator, LayoutC, arch::OpClassTensorOp, - Stages, MathOperatorTag, true - >; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::SmemThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - layout::TensorNCxHWx, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::SmemThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - layout::TensorCxRSKx, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Global, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - InterleavedK - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm -/// and 2 stage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA, - AccessTypeA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB, - AccessTypeB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage -/// pipeline with interleaved layout. -template < - typename ElementA, - typename ElementB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB, - int InterleavedK -> -struct DefaultConv2dFprop < - ElementA, - layout::TensorNCxHWx, - ElementB, - layout::TensorCxRSKx, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, - ElementB, layout::RowMajorInterleaved, - ElementAccumulator, LayoutC, arch::OpClassTensorOp, - 2, MathOperatorTag, true>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::SmemThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, layout::TensorNCxHWx, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::SmemThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, layout::TensorCxRSKx, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - InterleavedK - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport, - 4 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport, - 4 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport, - 4 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport, - 4 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h deleted file mode 100644 index 85b142a0e27d3c39d2d742c1709582ea3156b801..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h +++ /dev/null @@ -1,357 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief - Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution - definitions that combine threadblock-scoped matrix multiply-add with the - appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" -#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for fused batch norm and Conv2dFprop -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementScaleBias, - typename LayoutScaleBias, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kUnity -> struct DefaultConv2dFpropFusion; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassTensorOp convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -/// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementScaleBias, - typename LayoutScaleBias, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv2dFpropFusion < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementScaleBias, - LayoutScaleBias, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - /// Define iterators over tiles from scale/bias vectors - using IteratorScaleBias = - cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< - cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, - LayoutScaleBias>; - - using SmemIteratorScaleBias = - cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< - cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, - LayoutScaleBias>; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static int const kThreadCount = 32; - - // Warp-level iterators to load scale and bias vectors - using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< - MatrixShape, ElementScaleBias, - LayoutScaleBias, MatrixShape, - typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, - MmaCore::WarpCount::kK>; - - // Define the Mma - using Mma = threadblock::ImplicitGemmFpropFusionMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Global, - IteratorScaleBias, - SmemIteratorScaleBias, - arch::CacheOperation::Always, - MmaPolicy, - WarpIteratorScaleBias, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and -/// multistage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementScaleBias, - typename LayoutScaleBias, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv2dFpropFusion < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementScaleBias, - LayoutScaleBias, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag - >; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - /// Define iterators over tiles from scale/bias vectors - using IteratorScaleBias = - cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< - cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, - LayoutScaleBias>; - - using SmemIteratorScaleBias = - cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< - cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, - LayoutScaleBias>; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static int const kThreadCount = 32; - - // Warp-level iterators to load scale and bias vectors - using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< - MatrixShape, ElementScaleBias, - LayoutScaleBias, MatrixShape, - typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, - MmaCore::WarpCount::kK>; - - // Define the Mma - using Mma = threadblock::ImplicitGemmFpropFusionMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Global, - IteratorScaleBias, - SmemIteratorScaleBias, - arch::CacheOperation::Always, - MmaPolicy, - WarpIteratorScaleBias, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h deleted file mode 100644 index ccc751535c7a8c2c2f49b8d34f9d0e9a8edbd90e..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h +++ /dev/null @@ -1,127 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Defines a default configuration for convolution with absolute maximum calculation. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/conv/kernel/default_conv2d_fprop.h" -#include "cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h" - -#include "cutlass/epilogue/threadblock/default_epilogue_with_absmax.h" -#include "cutlass/epilogue/threadblock/epilogue_with_absmax.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kUnity, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> -struct DefaultConv2dFpropWithAbsMax { - - using ImplicitGemmBase = typename DefaultConv2dFprop< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - StrideSupport, - AlignmentA, - AlignmentB - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithAbsMax< - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ImplicitGemmBase::Epilogue::kPartitionsK, - ElementC, - typename EpilogueOutputOp::ElementAuxOutput, - ElementC, - EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithAbsMax< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h deleted file mode 100644 index b7fca981b0e0b44dca2b9add89808ac2b036d021..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h +++ /dev/null @@ -1,221 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/conv/kernel/default_conv2d_fprop.h" -#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" - -#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" -#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kUnity, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> -struct DefaultConv2dFpropWithBroadcast { - - using ImplicitGemmBase = typename DefaultConv2dFprop< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - StrideSupport, - AlignmentA, - AlignmentB - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< - ArchTag, - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ImplicitGemmBase::Epilogue::kPartitionsK, - ElementC, - typename EpilogueOutputOp::ElementT, - typename EpilogueOutputOp::ElementVector, - EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dFpropWithBroadcast < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - StrideSupport, - AlignmentA, - AlignmentB -> { - - using ImplicitGemmBase = typename DefaultConv2dFprop< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - StrideSupport, - AlignmentA, - AlignmentB - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< - ArchTag, - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ElementC, - typename EpilogueOutputOp::ElementT, - typename EpilogueOutputOp::ElementVector, - EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h deleted file mode 100644 index 5c2c7ffc700b089e449d4f18008c26cdb8d6c81a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h +++ /dev/null @@ -1,130 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Defines a GEMM with Reduction based on an existing UniversalGemm kernel. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/conv/kernel/default_conv2d_fprop.h" -#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" - -#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" -#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename EpilogueReductionOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kUnity, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> -struct DefaultConv2dFpropWithReduction { - - using ImplicitGemmBase = typename DefaultConv2dFprop< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - StrideSupport, - AlignmentA, - AlignmentB - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithReductionTensorOp< - ArchTag, - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ImplicitGemmBase::Epilogue::kPartitionsK, - ElementC, - EpilogueOutputOp, - EpilogueReductionOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h deleted file mode 100644 index 99e353d80a0b3b37818371737c8189eee6b5ed38..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h +++ /dev/null @@ -1,622 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped - matrix multiply-add with the appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h" -#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h" - -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dGroupFprop -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::GroupMode GroupMode, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kUnity, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> struct DefaultConv2dGroupFprop; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassTensorOp convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and multistage -/// pipeline that supports all GroupMode. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::GroupMode GroupMode, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dGroupFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - GroupMode, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AlignmentA, - AlignmentB -> { - - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA, - AccessTypeA, - GroupMode - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - AccessTypeB, - GroupMode - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv2dProblemSize, - GroupMode - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and -/// 2 stage pipeline that supports all GroupMode. - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::GroupMode GroupMode, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dGroupFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - GroupMode, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AlignmentA, - AlignmentB -> { - - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA, - AccessTypeA, - GroupMode - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - AccessTypeB, - GroupMode - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv2dProblemSize, - GroupMode - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and multistage -/// pipeline that supports GroupMode::kSingleGroup. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dGroupFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - GroupMode::kSingleGroup, - IteratorAlgorithm::kOptimized, - StrideSupport, - AlignmentA, - AlignmentB -> { - - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA, - AccessTypeA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - AccessTypeB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv2dProblemSize, - GroupMode::kSingleGroup - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and -/// 2 stage pipeline that supports GroupMode::kSingleGroup. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dGroupFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - GroupMode::kSingleGroup, - IteratorAlgorithm::kOptimized, - StrideSupport, - AlignmentA, - AlignmentB -> { - - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - static_assert(platform::is_same::value, - "Current group conv only support NHWC layout"); - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA, - AccessTypeA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB, - AccessTypeB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv2dProblemSize, - GroupMode::kSingleGroup - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h deleted file mode 100644 index d55d453eb02675d0b626865b6625dc4bf2b12e92..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h +++ /dev/null @@ -1,1011 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped - matrix multiply-add with the appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dWgrad -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> struct DefaultConv2dWgrad; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassTensorOp convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - AccessTypeA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - AccessTypeB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and two -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - AccessTypeA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - AccessTypeB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - AccessTypeA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - AccessTypeB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and two -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv2dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::AlignedArray; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - AccessTypeA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - AccessTypeB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - kPartitionsK, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AccessTypeA, - int AccessTypeB -> -struct DefaultConv2dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AccessTypeA, - AccessTypeB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AccessTypeA, - int AccessTypeB -> -struct DefaultConv2dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport, - AccessTypeA, - AccessTypeB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AccessTypeA, - int AccessTypeB -> -struct DefaultConv2dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport, - AccessTypeA, - AccessTypeB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AccessTypeA, - int AccessTypeB -> -struct DefaultConv2dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport, - AccessTypeA, - AccessTypeB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h deleted file mode 100644 index 83b680ec3591de39470013d71b808f356306b2f0..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h +++ /dev/null @@ -1,325 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief - Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped - matrix multiply-add with the appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -#include "cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dWgrad -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementScaleBias, - typename LayoutScaleBias, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided -> struct DefaultConv2dWgradFusion; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassTensorOp convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementScaleBias, - typename LayoutScaleBias, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv2dWgradFusion < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementScaleBias, - LayoutScaleBias, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - /// Define iterators over tiles from scale/bias vectors - using IteratorScaleBias = - cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator< - cutlass::MatrixShape<1, WarpShape::kN>, - ElementScaleBias, - LayoutScaleBias>; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmWgradFusionMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - IteratorScaleBias, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementScaleBias, - typename LayoutScaleBias, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv2dWgradFusion < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementScaleBias, - LayoutScaleBias, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - /// Define iterators over tiles from scale/bias vectors - using IteratorScaleBias = - cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator< - cutlass::MatrixShape<1, WarpShape::kN>, - ElementScaleBias, - LayoutScaleBias>; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmWgradFusionMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - IteratorScaleBias, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h deleted file mode 100644 index 309924cebafe82df1651b0fb5542eb14dc6c5388..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h +++ /dev/null @@ -1,736 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped - matrix multiply-add with the appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h" - -#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv3dDgrad -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided -> struct DefaultConv3dDgrad; - -/// Defines a kernel for Conv3dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided -// and multistage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv3dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport::kStrided -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kStrided - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Global, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad, - Conv3dProblemSize - >; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided -// and multistage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv3dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kUnity -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - - using IteratorB = - cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Global, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad, - Conv3dProblemSize - >; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv3dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kStrided -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad, - Conv3dProblemSize - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv3dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kUnity -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB - // ThreadMapB, - // StrideSupport::kUnity - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag -> -struct DefaultConv3dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kStrided -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - // cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - // > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - // cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - // > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag -> -struct DefaultConv3dDgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kUnity -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - // cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity - // > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - // cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB - // ThreadMapB, - // StrideSupport::kUnity - // > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDgrad, - Conv3dProblemSize - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h deleted file mode 100644 index 4b6709f08a4b2e93a0e3b93e1a343896368451c2..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h +++ /dev/null @@ -1,981 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped - matrix multiply-add with the appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" - - -#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv3dFprop -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kUnity -> struct DefaultConv3dFprop; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dFprop specialization for Analytic Iterator Algorithm -/// and 2 stage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport -> -struct DefaultConv3dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm and multistage -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport -> -struct DefaultConv3dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Global, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dFprop specialization for Optimized Iterator Algorithm -/// and 2 stage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport -> -struct DefaultConv3dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm and multistage -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport -> -struct DefaultConv3dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - - using IteratorB = - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Global, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport -> -struct DefaultConv3dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport -> -struct DefaultConv3dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport -> -struct DefaultConv3dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - StrideSupport -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport -> -struct DefaultConv3dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h deleted file mode 100644 index 513de059c6591a47fbf2c75f81d1400c96fe9d48..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h +++ /dev/null @@ -1,360 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution - definitions that combine threadblock-scoped matrix multiply-add with the - appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" -#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for fused batch norm and Conv3dFprop -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementScaleBias, - typename LayoutScaleBias, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kUnity -> struct DefaultConv3dFpropFusion; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassTensorOp convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dFprop specialzation for Analytic IteratorAlgorithm and multistage -/// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementScaleBias, - typename LayoutScaleBias, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv3dFpropFusion < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementScaleBias, - LayoutScaleBias, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - /// Define iterators over tiles from scale/bias vectors - using IteratorScaleBias = - cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< - cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, - LayoutScaleBias>; - - using SmemIteratorScaleBias = - cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< - cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, - LayoutScaleBias>; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static int const kThreadCount = 32; - - // Warp-level iterators to load scale and bias vectors - using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< - MatrixShape, ElementScaleBias, - LayoutScaleBias, MatrixShape, - typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, - MmaCore::WarpCount::kK>; - - // Define the Mma - using Mma = threadblock::ImplicitGemmFpropFusionMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Global, - IteratorScaleBias, - SmemIteratorScaleBias, - arch::CacheOperation::Always, - MmaPolicy, - WarpIteratorScaleBias, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dFprop specialzation for Optimized IteratorAlgorithm and -/// multistage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementScaleBias, - typename LayoutScaleBias, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv3dFpropFusion < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementScaleBias, - LayoutScaleBias, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag - >; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - /// Define iterators over tiles from scale/bias vectors - using IteratorScaleBias = - cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< - cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, - LayoutScaleBias>; - - using SmemIteratorScaleBias = - cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< - cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, - LayoutScaleBias>; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - static int const kThreadCount = 32; - - // Warp-level iterators to load scale and bias vectors - using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< - MatrixShape, ElementScaleBias, - LayoutScaleBias, MatrixShape, - typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, - MmaCore::WarpCount::kK>; - - // Define the Mma - using Mma = threadblock::ImplicitGemmFpropFusionMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Global, - IteratorScaleBias, - SmemIteratorScaleBias, - arch::CacheOperation::Always, - MmaPolicy, - WarpIteratorScaleBias, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h deleted file mode 100644 index 2fb12c2a502f9af2aa5383288e6695a108abdf60..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h +++ /dev/null @@ -1,222 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/conv/kernel/default_conv3d_fprop.h" -#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" - -#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" -#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kUnity, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> -struct DefaultConv3dFpropWithBroadcast { - - using ImplicitGemmBase = typename DefaultConv3dFprop< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - StrideSupport - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< - ArchTag, - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ImplicitGemmBase::Epilogue::kPartitionsK, - ElementC, - typename EpilogueOutputOp::ElementT, - typename EpilogueOutputOp::ElementVector, - EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultConv3dFpropWithBroadcast < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - StrideSupport, - AlignmentA, - AlignmentB -> { - - using ImplicitGemmBase = typename DefaultConv3dFprop< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - StrideSupport - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< - ArchTag, - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ElementC, - typename EpilogueOutputOp::ElementT, - typename EpilogueOutputOp::ElementVector, - EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess, - layout::NoPermute, - StrideSupport, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h deleted file mode 100644 index 6b50d2087e20889a934eaf34c7f120badff8a435..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h +++ /dev/null @@ -1,936 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped - matrix multiply-add with the appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dWgrad -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided -> struct DefaultConv3dWgrad; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and multistage -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv3dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and two -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag -> -struct DefaultConv3dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm and multistage -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv3dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm and two -// pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag -> -struct DefaultConv3dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad, - Conv3dProblemSize - >; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv3dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad, - Conv3dProblemSize - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv3dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad, - Conv3dProblemSize - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag -> -struct DefaultConv3dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag -> -struct DefaultConv3dWgrad < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - ThreadMapB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kWgrad, - Conv3dProblemSize - >; - -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d.h deleted file mode 100644 index a58046ffa414e6556d14b20c5402fb5d82cfbf64..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d.h +++ /dev/null @@ -1,999 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped - matrix multiply-add with the appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Deconv2d -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> struct DefaultDeconv2d; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Deconv2d specialization for Analytic IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv2d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kUnity - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - cutlass::AlignedArray, - conv::GroupMode::kNone, - true /*IsDeconv*/ - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport::kStrided, - 4 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv2d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - cutlass::AlignedArray, - conv::GroupMode::kNone, - true /*IsDeconv*/ - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Deconv2d specialization for Optimized IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv2d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - cutlass::AlignedArray, - true /*IsDeconv*/ - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport::kStrided, - 4 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv2d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - conv::StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - cutlass::AlignedArray, - true /*IsDeconv*/ - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv - >; - -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Deconv2d specialization for Analytic IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv2d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kUnity - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - cutlass::AlignedArray, - conv::GroupMode::kNone, - true /*IsDeconv*/ - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport::kStrided, - 4 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv - >; - -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv2d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - cutlass::AlignedArray, - conv::GroupMode::kNone, - true /*IsDeconv*/ - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Deconv2d specialization for Optimized IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv2d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - cutlass::AlignedArray, - true /*IsDeconv*/ - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport::kStrided, - 4 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv2d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - conv::StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - cutlass::AlignedArray, - true /*IsDeconv*/ - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv - >; - -}; - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h deleted file mode 100644 index e62187e3680e55a71d77bf4fee19276357753f98..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h +++ /dev/null @@ -1,305 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/conv/kernel/default_deconv2d.h" -#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" - -#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" -#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> -struct DefaultDeconv2dWithBroadcast { - - using ImplicitGemmBase = typename DefaultDeconv2d< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - StrideSupport, - AlignmentA, - AlignmentB - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< - ArchTag, - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ImplicitGemmBase::Epilogue::kPartitionsK, - ElementC, - typename EpilogueOutputOp::ElementT, - typename EpilogueOutputOp::ElementVector, - EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Deconv2d specialization, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv2dWithBroadcast < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - conv::StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - using ImplicitGemmBase = typename DefaultDeconv2d< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - conv::StrideSupport::kUnity, - AlignmentA, - AlignmentB - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< - ArchTag, - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ElementC, - typename EpilogueOutputOp::ElementT, - typename EpilogueOutputOp::ElementVector, - EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv2dWithBroadcast < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - conv::StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - using ImplicitGemmBase = typename DefaultDeconv2d< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - conv::StrideSupport::kStrided, - AlignmentA, - AlignmentB - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimtStridedDgrad< - ArchTag, - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ElementC, - typename EpilogueOutputOp::ElementT, - typename EpilogueOutputOp::ElementVector, - EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d.h deleted file mode 100644 index cb7ca07e6eb9b18f3006d51e742772f755852e23..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d.h +++ /dev/null @@ -1,541 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped - matrix multiply-add with the appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" - -#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h" -#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" - -#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Deconv3d -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided -> struct DefaultDeconv3d; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultDeconv3d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kStrided -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - true /*IsDeconv*/ - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport::kStrided, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv, - Conv3dProblemSize - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Deconv3d specialization for Optimized IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultDeconv3d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kUnity -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB, - true /*IsDeconv*/ - // ThreadMapB, - // StrideSupport::kUnity - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - arch::CacheOperation::Always, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport::kStrided, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag -> -struct DefaultDeconv3d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kStrided -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - // cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - conv::StrideSupport::kStrided - // > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - // cutlass::conv::threadblock::TileIteratorStridedDgrad< - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, - ThreadMapB, - true /*IsDeconv*/ - // > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport::kStrided, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Deconv3d specialization for Optimized IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag -> -struct DefaultDeconv3d < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport::kUnity -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - // cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - ThreadMapA, - StrideSupport::kUnity - // > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - // cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB, - true /*IsDeconv*/ - // ThreadMapB, - // StrideSupport::kUnity - // > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - false, - layout::NoPermute, - StrideSupport::kStrided, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv, - Conv3dProblemSize - >; - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h deleted file mode 100644 index e25c8b2eee551252b902e0c0845416b753194df1..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h +++ /dev/null @@ -1,309 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/conv/kernel/default_deconv3d.h" -#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" - -#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" -#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> -struct DefaultDeconv3dWithBroadcast { - - using ImplicitGemmBase = typename DefaultDeconv3d< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - StrideSupport - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< - ArchTag, - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ImplicitGemmBase::Epilogue::kPartitionsK, - ElementC, - typename EpilogueOutputOp::ElementT, - typename EpilogueOutputOp::ElementVector, - EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Deconv3d specialization for Analytic IteratorAlgorithm, -/// multi-stage pipeline, and FFMA-based mainloop for SM80 - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv3dWithBroadcast < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - conv::StrideSupport::kUnity, - AlignmentA, - AlignmentB -> { - - using ImplicitGemmBase = typename DefaultDeconv3d< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - conv::StrideSupport::kUnity - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< - ArchTag, - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ElementC, - typename EpilogueOutputOp::ElementT, - typename EpilogueOutputOp::ElementVector, - EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess, - layout::NoPermute, - StrideSupport::kStrided, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm, - int AlignmentA, - int AlignmentB -> -struct DefaultDeconv3dWithBroadcast < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - conv::StrideSupport::kStrided, - AlignmentA, - AlignmentB -> { - - using ImplicitGemmBase = typename DefaultDeconv3d< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm, - conv::StrideSupport::kStrided - >::Kernel; - - // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< - ArchTag, - typename ImplicitGemmBase::Epilogue::Shape, - typename ImplicitGemmBase::Epilogue::WarpMmaOperator, - ElementC, - typename EpilogueOutputOp::ElementT, - typename EpilogueOutputOp::ElementVector, - EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess, - layout::NoPermute, - StrideSupport::kStrided, - 5 - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< - typename ImplicitGemmBase::Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kDeconv, - Conv3dProblemSize - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h deleted file mode 100644 index ba70813e4c94104522a05897a60811d26ae3c6a4..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h +++ /dev/null @@ -1,588 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level Depthwise implicit GEMM convolution definitions combine threadblock-scoped - matrix multiply-add with the appropriate threadblock-scoped epilogue. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d.h" -#include "cutlass/conv/kernel/direct_convolution.h" - -#include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h" - -#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h" - -// Direct Conv Related Header files -#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h" -#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h" - -#include "cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h" -#include "cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for DepthwiseFprop -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - conv::StrideSupport StrideSupport = StrideSupport::kUnity, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value -> struct DefaultDepthwiseFprop; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for DepthwiseFprop with direct convolution algorithm -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename OperatorClass, - typename ArchTag, - typename ThreadblockShape, - typename ThreadBlockOutputShape, - typename FilterShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - conv::StrideSupport StrideSupport = StrideSupport::kUnity, - // MatrixShape - typename StrideShape = cutlass::MatrixShape<-1, -1>, - // MatrixShape< Height, Width> - typename DilationShape = cutlass::MatrixShape<-1, -1>, - /// Access granularity of A matrix in units of elements - int AlignmentA = 128 / cutlass::sizeof_bits::value, - /// Access granularity of B matrix in units of elements - int AlignmentB = 128 / cutlass::sizeof_bits::value -> struct DefaultDepthwiseDirect2dConvFprop; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// OpClassSimt convolutions -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - int AlignmentA, - int AlignmentB -> -struct DefaultDepthwiseFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, // cutlass::arch::OpMultiplyAdd - IteratorAlgorithm::kAnalytic, - StrideSupport, - AlignmentA, - AlignmentB -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::conv::threadblock::DepthwiseMmaCoreWithLaneAccessSize< - ThreadblockShape, - WarpShape, - InstructionShape, - ElementA, - layout::RowMajor, - ElementB, - layout::ColumnMajor, - ElementAccumulator, - layout::RowMajor, - arch::OpClassSimt, - 128, - sizeof_bits::value, - 2, - MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementA, LayoutA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB, - AccessTypeB, - cutlass::conv::GroupMode::kDepthwise - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::DepthwiseFpropPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< - ThreadblockShape, - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv2dProblemSize, - cutlass::conv::GroupMode::kDepthwise - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, -/// multiple stage pipeline, and SIMT-based mainloop -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename ThreadBlockOutputShape, - typename FilterShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - typename StrideShape, - typename DilationShape, - int AlignmentA, - int AlignmentB -> -struct DefaultDepthwiseDirect2dConvFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - ThreadBlockOutputShape, - FilterShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized, - StrideSupport, - StrideShape, - DilationShape, - AlignmentA, - AlignmentB -> { - // One warp handles the entrie groups per cta. - static_assert(ThreadblockShape::kN == WarpShape::kN, - "ThreadblockShape::kN should be same as WarpShape::kN "); - static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, - "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); - static_assert(ThreadblockShape::kM % WarpShape::kM == 0, - "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); - static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); - - // Define the core components from GEMM - using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< - ThreadblockShape, - ThreadBlockOutputShape, - FilterShape, - WarpShape, - InstructionShape, - ElementA, - layout::RowMajor, - ElementB, - layout::ColumnMajor, - ElementAccumulator, - layout::RowMajor, - arch::OpClassSimt, - 128, - 128, - Stages, - MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized< - cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> - ThreadBlockOutputShape, - ElementA, LayoutA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - using ThreadOutputShape = typename MmaCore::ThreadOutputShape; - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * AlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< - ThreadblockShape, // < outputShape:KMNK, groups per cta> - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - ThreadOutputShape, - ThreadBlockOutputShape - >::Epilogue; - - // Define the Mma - using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - CacheOpA, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages, - Epilogue - >; - - // Define the kernel - using Kernel = cutlass::conv::kernel::DirectConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv2dProblemSize, - cutlass::conv::GroupMode::kDepthwise, - ThreadBlockOutputShape - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, -/// multiple stage pipeline, and SIMT-based mainloop -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename ThreadBlockOutputShape, - typename FilterShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag, - conv::StrideSupport StrideSupport, - typename StrideShape, - typename DilationShape, - int AlignmentA, - int AlignmentB -> -struct DefaultDepthwiseDirect2dConvFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassSimt, - ArchTag, - ThreadblockShape, - ThreadBlockOutputShape, - FilterShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kFixedStrideDilation, - StrideSupport, - StrideShape, - DilationShape, - AlignmentA, - AlignmentB -> { - - - - // One warp handles the entrie groups per cta. - static_assert(ThreadblockShape::kN == WarpShape::kN, - "ThreadblockShape::kN should be same as WarpShape::kN "); - static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, - "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); - static_assert(ThreadblockShape::kM % WarpShape::kM == 0, - "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); - static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); - - static_assert(StrideShape::kRow >= 0 && StrideShape::kColumn >= 0, "Stride should be fixed"); - static_assert(DilationShape::kRow >= 0 && DilationShape::kColumn >= 0, "Stride should be fixed"); - - // Activations loaded by threadblock - static int const ActivationShapeH = (ThreadBlockOutputShape::kH - 1) * StrideShape::kRow + - (FilterShape::kRow - 1) * DilationShape::kRow + 1; - - static int const ActivationShapeW = (ThreadBlockOutputShape::kW - 1) * StrideShape::kColumn + - (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; - - using ActivationShape = - cutlass::conv::TensorNHWCShape<1, ActivationShapeH, ActivationShapeW, ThreadblockShape::kN >; - - // Define the core components from GEMM - using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< - ThreadblockShape, - ThreadBlockOutputShape, - FilterShape, - WarpShape, - InstructionShape, - ElementA, - layout::RowMajor, - ElementB, - layout::ColumnMajor, - ElementAccumulator, - layout::RowMajor, - arch::OpClassSimt, - 128, - 128, - Stages, - MathOperatorTag, - IteratorAlgorithm::kFixedStrideDilation, - StrideShape, - DilationShape, - ActivationShape>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation< - cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> - ThreadBlockOutputShape, - StrideShape, - DilationShape, - ActivationShape, - ElementA, LayoutA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::AlignedArray; - using IteratorB = - cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, LayoutB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; - using MmaPolicy = typename MmaCore::MmaPolicy; - using ThreadOutputShape = typename MmaCore::ThreadOutputShape; - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * AlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< - ThreadblockShape, // < outputShape:KMNK, groups per cta> - WarpMmaSimtOp, - EpilogueOutputOp, - EpilogueOutputOp::kCount, - ThreadOutputShape, - ThreadBlockOutputShape - >::Epilogue; - - // Define the Mma - using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< - ThreadblockShape, - IteratorA, - SmemIteratorA, - CacheOpA, - IteratorB, - SmemIteratorB, - CacheOpB, - MmaPolicy, - Stages, - Epilogue, - IteratorAlgorithm::kFixedStrideDilation - >; - - // Define the kernel - using Kernel = cutlass::conv::kernel::DirectConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop, - Conv2dProblemSize, - cutlass::conv::GroupMode::kDepthwise, - ThreadBlockOutputShape - >; -}; - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/direct_convolution.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/direct_convolution.h deleted file mode 100644 index 8c04988790b9b03e41e9c2245dbdf2e5e8af493b..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/direct_convolution.h +++ /dev/null @@ -1,506 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a multi-staged Depthwise Convolution kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/platform/platform.h" -#include "cutlass/semaphore.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters structure -template > ///! OutputShape per ThreadBlock -struct DirectConvolutionParams { - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - using ThreadBlockOutputShape = ThreadBlockOutputShape_; - static Operator const kConvolutionalOperator = ConvOperator; - using ConvProblemSize = ConvProblemSize_; - using Arguments = Arguments_; - using ConvOutputIteratorParameter = ConvOutputIteratorParameter_; - - using ThreadblockShape = typename Mma::Shape; - static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; - static conv::GroupMode const kGroupMode = GroupMode_; - static int const kStages = Mma::kStages; - - ConvProblemSize problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - gemm::GemmCoord implicit_gemm_problem_size; - int swizzle_log_tile; - int smem_size_; - - int gemm_k_iterations; - int gemm_k_iterations_per_channel; - typename Mma::IteratorA::Params iterator_A; - typename Mma::IteratorA::Element const *ptr_A; - typename Mma::IteratorB::Params iterator_B; - typename Mma::IteratorB::Element const *ptr_B; - typename Mma::IteratorB::Element *ptr_reordered_B; - typename Epilogue::OutputTileIterator::Params iterator_C; - typename Epilogue::OutputTileIterator::Element *ptr_C; - typename Epilogue::OutputTileIterator::Params iterator_D; - typename Epilogue::OutputTileIterator::Element *ptr_D; - typename EpilogueOutputOp::Params output_op; - int *semaphore; - SplitKMode split_k_mode; - int split_k_slices; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - DirectConvolutionParams() : swizzle_log_tile(0), gemm_k_iterations(0) {} - - /// - CUTLASS_HOST_DEVICE - DirectConvolutionParams(Arguments const &args, int *semaphore = nullptr) - : problem_size(args.problem_size), - implicit_gemm_problem_size( - cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), - iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), - ptr_A(args.ref_A.data()), - iterator_B(Mma::IteratorB::getParams(args.problem_size, args.ref_B.layout())), - ptr_B(args.ref_B.data()), - ptr_reordered_B(args.ref_reordered_B.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size), - ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size), - ptr_D(args.ref_D.data()), - output_op(args.output_op), - semaphore(semaphore), - split_k_mode(args.split_k_mode), - split_k_slices(args.problem_size.split_k_slices) { - gemm_k_iterations = - depthwise_gemm_k_iterations(kConvolutionalOperator, - ThreadblockShape::kK, - args.problem_size, - kIteratorAlgorithm, - kGroupMode, - ThreadblockShape::kN); - - gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( - kConvolutionalOperator, args.problem_size, kIteratorAlgorithm); - - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - kConvolutionalOperator, - problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.problem_size.split_k_slices); - - swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); - - // Dynamic SMEM usage because stride and dilation are runtime params. - smem_size_ = (cutlass::platform::max(iterator_A.activation_size, int(sizeof(typename Epilogue::SharedStorage))) * kStages + iterator_B.filter_size); - } - - CUTLASS_HOST_DEVICE - int get_smem_size() { - // Dynamic Smem Size - return smem_size_; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct ReorderKernel { - using Params = Params_; - using ElementB = ElementB_; - - union SharedStorage {}; - - static unsigned int const kReorderKernelThreadPerCTA = 128; - - CUTLASS_HOST_DEVICE - ReorderKernel() {} - - CUTLASS_HOST_DEVICE - static dim3 get_grid_shape(Params const ¶ms) { - return dim3{static_cast( - (params.problem_size.filter_size() + kReorderKernelThreadPerCTA - 1) / - kReorderKernelThreadPerCTA), - 1, - 1}; - } - - CUTLASS_HOST_DEVICE - static dim3 get_block_shape() { return dim3{kReorderKernelThreadPerCTA, 1, 1}; } - - CUTLASS_HOST_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - int64_t m = static_cast(params.problem_size.groups); - int64_t n = static_cast(params.problem_size.filter_size() / params.problem_size.K); - const ElementB *src_with_type = static_cast(params.ptr_B); - ElementB *dst_with_type = static_cast(params.ptr_reordered_B); - - int64_t linear_index = blockIdx.x * kReorderKernelThreadPerCTA + threadIdx.x; - int64_t index_m = linear_index / n; - int64_t index_n = linear_index % n; - int64_t new_linear_index = index_m + index_n * m; - - if (linear_index < m * n) { - dst_with_type[new_linear_index] = src_with_type[linear_index]; - } - return; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) - typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem - conv::GroupMode GroupMode_ = conv::GroupMode::kNone, ///! Group mode - typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> -> -struct DirectConvolution { - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - using ThreadBlockOutputShape = ThreadBlockOutputShape_; - static Operator const kConvolutionalOperator = ConvOperator; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename EpilogueOutputOp::ElementOutput; - - /// Set output tensor C layout - using LayoutC = LayoutA; - - using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; - using ElementCompute = typename EpilogueOutputOp::ElementCompute; - - using WarpMmaOperator = typename Mma::Policy::Operator; - - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename ArchMmaOperator::Operator; - - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename WarpMmaOperator::Shape; - using InstructionShape = typename cutlass::gemm::GemmShape<1, 1, 1>; - - static int const kStages = Mma::kStages; - static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; - static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using TensorRefA = typename Mma::IteratorA::TensorRef; - using TensorRefB = typename Mma::IteratorB::TensorRef; - using TensorRefC = cutlass::TensorRef; - - /// Check iterator A and B convolution dimension are the same and - // set device::ImplicitGemmConvolution::kConvDim - static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, - "Convolution on different different dimensions is not supported"); - static int const kConvDim = Mma::IteratorA::kConvDim; - - /// Conv dimension and problem size structure (Conv2d or Conv3d) - using ConvProblemSize = ConvProblemSize_; - - static conv::GroupMode const kGroupMode = GroupMode_; - - - // - // - // - using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< - LayoutC, - typename Epilogue::OutputTileIterator::Layout, - TensorRefC, - ConvOperator, - ConvProblemSize - >; - - - /// Argument structure - struct Arguments { - - // - // Data members - // - - ConvProblemSize problem_size; - TensorRefA ref_A; - TensorRefB ref_B; - TensorRefB ref_reordered_B; - TensorRefC ref_C; - TensorRefC ref_D; - typename EpilogueOutputOp::Params output_op; - SplitKMode split_k_mode; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() { } - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size - ): - problem_size(problem_size) { } - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size, - TensorRefA const & ref_A, - TensorRefB const & ref_B, - TensorRefC const & ref_C, - TensorRefC const & ref_D, - typename EpilogueOutputOp::Params const & output_op, - TensorRefB const & ref_reordered_B = nullptr, - SplitKMode const & split_k_mode = SplitKMode::kSerial - ): - problem_size(problem_size), - ref_A(ref_A), - ref_B(ref_B), - ref_C(ref_C), - ref_D(ref_D), - output_op(output_op), - ref_reordered_B(ref_reordered_B), - split_k_mode(split_k_mode) - { - - } - - }; - - using Params = - typename cutlass::conv::kernel::DirectConvolutionParams; - - using ReorderKernel = typename cutlass::conv::kernel::ReorderKernel; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - DirectConvolution() { } - - /// Executes one ImplicitGEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if threadblock is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || - params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { - - return; - } - - // Compute position within threadblock - int thread_idx = threadIdx.x; - int iterator_column_offset = 0; - int filter_row_offset = 0; - if (kGroupMode != GroupMode::kNone) { - if (kGroupMode == GroupMode::kDepthwise) { - iterator_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; - } - } - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.iterator_A, - params.problem_size, - params.ptr_A, - thread_idx, - MatrixCoord( - threadblock_tile_idx.m() + threadblock_tile_idx.k(), - iterator_column_offset - ) - ); - - typename Mma::IteratorB iterator_B( - params.iterator_B, - params.problem_size, - params.ptr_reordered_B, - thread_idx, - MatrixCoord( - filter_row_offset, - iterator_column_offset - ) - ); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // Compute logical position within grid - threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - - MatrixCoord threadblock_offset( - threadblock_tile_idx.m() + threadblock_tile_idx.k(), - threadblock_tile_idx.n() * Mma::Shape::kN - ); - - // Tile iterator writing to destination tensor - typename Epilogue::OutputTileIterator iterator_D( - params.iterator_D, - params.ptr_D, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset - ); - - // Tile iterator reading from source accumulator tensor - typename Epilogue::OutputTileIterator iterator_C( - params.iterator_C, - params.ptr_C, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset - ); - - - // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - - // Compute threadblock-scoped matrix multiply-add - // Epilogue is fused in the mainloop - mma(params.gemm_k_iterations, - accumulators, - iterator_A, - params.iterator_A, - iterator_B, - params.iterator_B, - accumulators, - epilogue, - output_op, - iterator_D, - iterator_C, - params.split_k_slices); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h deleted file mode 100644 index d3fa0e907bb94c2716861395324b2da0346cebde..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h +++ /dev/null @@ -1,455 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a pipelined Implicit GEMM kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/semaphore.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad, Deconv) - typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem - conv::GroupMode GroupMode_ = conv::GroupMode::kNone ///! Group mode -> -struct ImplicitGemmConvolution { - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static Operator const kConvolutionalOperator = ConvOperator; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename EpilogueOutputOp::ElementOutput; - - /// Set output tensor C layout - using LayoutC = LayoutA; - - using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; - using ElementCompute = typename EpilogueOutputOp::ElementCompute; - - using WarpMmaOperator = typename Mma::Policy::Operator; - - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename ArchMmaOperator::Operator; - - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename WarpMmaOperator::Shape; - using InstructionShape = typename ArchMmaOperator::Shape; - - static int const kStages = Mma::kStages; - static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; - static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using TensorRefA = typename Mma::IteratorA::TensorRef; - using TensorRefB = typename Mma::IteratorB::TensorRef; - using TensorRefC = cutlass::TensorRef; - - /// Check iterator A and B convolution dimension are the same and - // set device::ImplicitGemmConvolution::kConvDim - static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, - "Convolution on different different dimensions is not supported"); - static int const kConvDim = Mma::IteratorA::kConvDim; - - /// Conv dimension and problem size structure (Conv2d or Conv3d) - using ConvProblemSize = ConvProblemSize_; - - static conv::GroupMode const kGroupMode = GroupMode_; - - /// Wgrad C stride idx for implicit gemm algorithm - // Conv2d row-major matrix C (KxRSC) - // Conv3d row-major matrix C (KxTRSC) - static int const kWgradCStrideIdx = - platform::is_same::value ? 2 : 3; - - /// This chooses the appropriate stride element of the C tensor. - static int const kTensorCStrideIdx = - (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); - - // - // - // - using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< - LayoutC, - typename Epilogue::OutputTileIterator::Layout, - TensorRefC, - ConvOperator, - ConvProblemSize - >; - - /// Argument structure - struct Arguments { - - // - // Data members - // - - ConvProblemSize problem_size; - TensorRefA ref_A; - TensorRefB ref_B; - TensorRefC ref_C; - TensorRefC ref_D; - typename EpilogueOutputOp::Params output_op; - SplitKMode split_k_mode; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() { } - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size - ): - problem_size(problem_size) { } - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size, - TensorRefA const & ref_A, - TensorRefB const & ref_B, - TensorRefC const & ref_C, - TensorRefC const & ref_D, - typename EpilogueOutputOp::Params const & output_op, - SplitKMode const & split_k_mode = SplitKMode::kSerial - ): - problem_size(problem_size), - ref_A(ref_A), - ref_B(ref_B), - ref_C(ref_C), - ref_D(ref_D), - output_op(output_op), - split_k_mode(split_k_mode) - { - - } - - }; - - /// Parameters structure - struct Params { - ConvProblemSize problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - gemm::GemmCoord implicit_gemm_problem_size; - int swizzle_log_tile; - - int gemm_k_iterations; - int gemm_k_iterations_per_channel; - typename Mma::IteratorA::Params iterator_A; - typename Mma::IteratorA::Element const *ptr_A; - typename Mma::IteratorB::Params iterator_B; - typename Mma::IteratorB::Element const *ptr_B; - typename Epilogue::OutputTileIterator::Params iterator_C; - typename Epilogue::OutputTileIterator::Element *ptr_C; - typename Epilogue::OutputTileIterator::Params iterator_D; - typename Epilogue::OutputTileIterator::Element *ptr_D; - typename EpilogueOutputOp::Params output_op; - int *semaphore; - SplitKMode split_k_mode; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params(): swizzle_log_tile(0), gemm_k_iterations(0) { } - - /// - CUTLASS_HOST_DEVICE - Params( - Arguments const &args, - int *semaphore = nullptr - ): - problem_size(args.problem_size), - implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), - iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), - ptr_A(args.ref_A.data()), - iterator_B(args.problem_size, args.ref_B.layout()), - ptr_B(args.ref_B.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), - ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), - ptr_D(args.ref_D.data()), - output_op(args.output_op), - semaphore(semaphore), - split_k_mode(args.split_k_mode) - { - gemm_k_iterations = implicit_gemm_k_iterations( - kConvolutionalOperator, - ThreadblockShape::kK, - args.problem_size, - kIteratorAlgorithm, - kGroupMode, - ThreadblockShape::kN); - - gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( - kConvolutionalOperator, args.problem_size, kIteratorAlgorithm); - - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - implicit_gemm_problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.problem_size.split_k_slices); - - swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - ImplicitGemmConvolution() { } - - /// Executes one ImplicitGEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || - params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { - - return; - } - - // Compute position within threadblock - int thread_idx = threadIdx.x; - int iterator_A_column_offset = threadblock_tile_idx.k() * Mma::Shape::kK; - if (kGroupMode != GroupMode::kNone) { - if (kGroupMode != GroupMode::kDepthwise) { - int k_per_group = params.problem_size.K / params.problem_size.groups; - int group_idx = threadblock_tile_idx.n() * Mma::Shape::kN / k_per_group; - int channels_per_group = params.problem_size.C / params.problem_size.groups; - iterator_A_column_offset += group_idx * channels_per_group; - } else { - iterator_A_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; - } - } - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.iterator_A, - params.problem_size, - params.ptr_A, - thread_idx, - MatrixCoord( - threadblock_tile_idx.m() * Mma::Shape::kM, - iterator_A_column_offset - ) - ); - - typename Mma::IteratorB iterator_B( - params.iterator_B, - params.problem_size, - params.ptr_B, - thread_idx, - MatrixCoord( - threadblock_tile_idx.k() * Mma::Shape::kK, - threadblock_tile_idx.n() * Mma::Shape::kN - ) - ); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, params.gemm_k_iterations_per_channel); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // Construct the semaphore. - int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); - - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // Compute logical position within grid - threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // If performing a reduction via split-K, fetch the initial synchronization - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); - } - - MatrixCoord threadblock_offset( - threadblock_tile_idx.m() * Mma::Shape::kM, - threadblock_tile_idx.n() * Mma::Shape::kN - ); - - // Tile iterator writing to destination tensor - typename Epilogue::OutputTileIterator iterator_D( - params.iterator_D, - params.ptr_D, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset - ); - - // Tile iterator reading from source accumulator tensor - typename Epilogue::OutputTileIterator iterator_C( - params.iterator_C, - params.ptr_C, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset - ); - - // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_idx.k()) { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_idx.k()); - - } - // Each split-k-slice writes to a unique tensor location - else if (params.split_k_mode == SplitKMode::kParallel) { - iterator_D.add_pointer_offset(threadblock_tile_idx.k() * - cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); - } - - // Run efficient epilogue - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // - // Release the semaphore - // - - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_idx.k() + 1; - } - - semaphore.release(lock); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h deleted file mode 100644 index 5451c176f4027bc40a3ec3466efe69dea18f5342..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h +++ /dev/null @@ -1,461 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a pipelined fused activation's scale+bias+relu and Implicit GEMM kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/semaphore.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) - typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem -> -struct ImplicitGemmConvolutionFusion { - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static Operator const kConvolutionalOperator = ConvOperator; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - - using ElementScaleBias = typename Mma::IteratorScaleBias::Element; - using LayoutScaleBias = typename Mma::IteratorScaleBias::Layout; - - using ElementC = typename EpilogueOutputOp::ElementOutput; - using LayoutC = LayoutA; - - using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; - using ElementCompute = typename EpilogueOutputOp::ElementCompute; - - using WarpMmaOperator = typename Mma::Policy::Operator; - - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename ArchMmaOperator::Operator; - - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename WarpMmaOperator::Shape; - using InstructionShape = typename ArchMmaOperator::Shape; - - static int const kStages = Mma::kStages; - static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using TensorRefA = typename Mma::IteratorA::TensorRef; - using TensorRefB = typename Mma::IteratorB::TensorRef; - using TensorRefScaleBias = typename Mma::IteratorScaleBias::TensorRef; - using TensorRefC = cutlass::TensorRef; - - /// Check iterator A and B convolution dimension are the same and - // set device::ImplicitGemmConvolution::kConvDim - static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, - "Convolution on different different dimensions is not supported"); - static int const kConvDim = Mma::IteratorA::kConvDim; - - /// Conv dimension and problem size structure (Conv2d or Conv3d) - using ConvProblemSize = ConvProblemSize_; - - static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; - - /// Wgrad C stride idx for implicit gemm algorithm - // Conv2d row-major matrix C (KxRSC) - // Conv3d row-major matrix C (KxTRSC) - static int const kWgradCStrideIdx = - platform::is_same::value ? 2 : 3; - - /// This chooses the appropriate stride element of the C tensor. - static int const kTensorCStrideIdx = - (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); - - // - // - // - using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< - LayoutC, - typename Epilogue::OutputTileIterator::Layout, - TensorRefC, - ConvOperator, - ConvProblemSize - >; - - /// Argument structure - struct Arguments { - - // - // Data members - // - - ConvProblemSize problem_size; - TensorRefA ref_A; - TensorRefB ref_B; - TensorRefScaleBias ref_scale; - TensorRefScaleBias ref_bias; - TensorRefC ref_C; - TensorRefC ref_D; - typename EpilogueOutputOp::Params output_op; - SplitKMode split_k_mode; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() { } - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size - ): - problem_size(problem_size) { } - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size, - TensorRefA const & ref_A, - TensorRefB const & ref_B, - TensorRefScaleBias const & ref_scale, - TensorRefScaleBias const & ref_bias, - TensorRefC const & ref_C, - TensorRefC const & ref_D, - typename EpilogueOutputOp::Params const & output_op, - SplitKMode const & split_k_mode = SplitKMode::kSerial - ): - problem_size(problem_size), - ref_A(ref_A), - ref_B(ref_B), - ref_scale(ref_scale), - ref_bias(ref_bias), - ref_C(ref_C), - ref_D(ref_D), - output_op(output_op), - split_k_mode(split_k_mode) - { - - } - - }; - - /// Parameters structure - struct Params { - ConvProblemSize problem_size{}; - cutlass::gemm::GemmCoord grid_tiled_shape{}; - gemm::GemmCoord implicit_gemm_problem_size{}; - int swizzle_log_tile{0}; - int gemm_k_iterations{0}; - typename Mma::IteratorA::Params iterator_A{}; - typename Mma::IteratorA::Element const *ptr_A = nullptr; - typename Mma::IteratorB::Params iterator_B{}; - typename Mma::IteratorB::Element const *ptr_B = nullptr; - typename Mma::IteratorScaleBias::Params iterator_scale_bias{}; - typename Mma::IteratorScaleBias::Element const *ptr_scale = nullptr; - typename Mma::IteratorScaleBias::Element const *ptr_bias = nullptr; - typename Epilogue::OutputTileIterator::Params iterator_C {}; - typename Epilogue::OutputTileIterator::Element *ptr_C = nullptr; - typename Epilogue::OutputTileIterator::Params iterator_D {}; - typename Epilogue::OutputTileIterator::Element *ptr_D = nullptr; - typename EpilogueOutputOp::Params output_op {}; - int *semaphore = nullptr; - SplitKMode split_k_mode {}; - - // - // Methods - // - Params() = default; - - /// - CUTLASS_HOST_DEVICE - Params( - Arguments const &args, - int *semaphore = nullptr - ): - problem_size(args.problem_size), - implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), - iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), - ptr_A(args.ref_A.data()), - iterator_B(args.problem_size, args.ref_B.layout()), - ptr_B(args.ref_B.data()), - iterator_scale_bias(args.problem_size, args.ref_scale.layout()), - ptr_scale(args.ref_scale.data()), - ptr_bias(args.ref_bias.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), - ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), - ptr_D(args.ref_D.data()), - output_op(args.output_op), - semaphore(semaphore), - split_k_mode(args.split_k_mode) - { - gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); - - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - implicit_gemm_problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.problem_size.split_k_slices); - - swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - ImplicitGemmConvolutionFusion() { } - - /// Executes one ImplicitGEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || - params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { - - return; - } - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A operand - typename Mma::IteratorA iterator_A( - params.iterator_A, - params.problem_size, - params.ptr_A, - thread_idx, - MatrixCoord( - threadblock_tile_idx.m() * Mma::Shape::kM, - threadblock_tile_idx.k() * Mma::Shape::kK - ) - ); - - // Construct iterators to B operand - typename Mma::IteratorB iterator_B( - params.iterator_B, - params.problem_size, - params.ptr_B, - thread_idx, - MatrixCoord( - threadblock_tile_idx.k() * Mma::Shape::kK, - threadblock_tile_idx.n() * Mma::Shape::kN - ) - ); - - // Construct iterators to A scale/bias vector - typename Mma::IteratorScaleBias iterator_scale_bias( - params.iterator_scale_bias, - params.problem_size, - params.ptr_scale, - params.ptr_bias, - thread_idx, - MatrixCoord( - 0, (kConvolutionalOperator == conv::Operator::kFprop) ? - (threadblock_tile_idx.k() * Mma::Shape::kK) : - // Wgrad - (threadblock_tile_idx.n() * Mma::Shape::kN) - ) - ); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - mma(params.gemm_k_iterations, accumulators, iterator_A, - iterator_B, iterator_scale_bias, accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // Construct the semaphore. - int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); - - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // Compute logical position within grid - threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // If performing a reduction via split-K, fetch the initial synchronization - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); - } - - MatrixCoord threadblock_offset( - threadblock_tile_idx.m() * Mma::Shape::kM, - threadblock_tile_idx.n() * Mma::Shape::kN - ); - - // Tile iterator writing to destination tensor - typename Epilogue::OutputTileIterator iterator_D( - params.iterator_D, - params.ptr_D, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset - ); - - // Tile iterator reading from source accumulator tensor - typename Epilogue::OutputTileIterator iterator_C( - params.iterator_C, - params.ptr_C, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset - ); - - // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_idx.k()) { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_idx.k()); - - } - // Each split-k-slice writes to a unique tensor location - else if (params.split_k_mode == SplitKMode::kParallel) { - iterator_D.add_pointer_offset(threadblock_tile_idx.k() * - cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); - } - - // Run efficient epilogue - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // - // Release the semaphore - // - - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_idx.k() + 1; - } - - semaphore.release(lock); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h deleted file mode 100644 index 071854cd629e26417ca987bc24681665c8d30702..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h +++ /dev/null @@ -1,492 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a pipelined Implicit GEMM kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/semaphore.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) - typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem -> -struct ImplicitGemmConvolutionStridedDgrad { - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static Operator const kConvolutionalOperator = ConvOperator; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename EpilogueOutputOp::ElementOutput; - - /// Set output tensor C layout - using LayoutC = LayoutA; - - using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; - using ElementCompute = typename EpilogueOutputOp::ElementCompute; - - using WarpMmaOperator = typename Mma::Policy::Operator; - - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename ArchMmaOperator::Operator; - - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename WarpMmaOperator::Shape; - using InstructionShape = typename ArchMmaOperator::Shape; - - static int const kStages = Mma::kStages; - static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; - static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using TensorRefA = typename Mma::IteratorA::TensorRef; - using TensorRefB = typename Mma::IteratorB::TensorRef; - using TensorRefC = cutlass::TensorRef; - - /// Check iterator A and B convolution dimension are the same and - // set device::ImplicitGemmConvolution::kConvDim - static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, - "Convolution on different different dimensions is not supported"); - static int const kConvDim = Mma::IteratorA::kConvDim; - - /// Conv dimension and problem size structure (Conv2d or Conv3d) - using ConvProblemSize = ConvProblemSize_; - - static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; - - /// Wgrad C stride idx for implicit gemm algorithm - // Conv2d row-major matrix C (KxRSC) - // Conv3d row-major matrix C (KxTRSC) - static int const kWgradCStrideIdx = - platform::is_same::value ? 2 : 3; - - /// This chooses the appropriate stride element of the C tensor. - static int const kTensorCStrideIdx = - (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); - - // Strided dgrad uses a specialized threadblock swizzle for functionality and performance - static_assert((platform::is_same::value) || - (platform::is_same>::value) || - (platform::is_same>::value) || - (platform::is_same>::value), - "Needs ThreadblockSwizzle type specialized for strided dgrad"); - - // - // - // - using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< - LayoutC, - typename Epilogue::OutputTileIterator::Layout, - TensorRefC, - ConvOperator, - ConvProblemSize - >; - - /// Argument structure - struct Arguments { - - // - // Data members - // - - ConvProblemSize problem_size{}; - TensorRefA ref_A{}; - TensorRefB ref_B{}; - TensorRefC ref_C{}; - TensorRefC ref_D{}; - typename EpilogueOutputOp::Params output_op{}; - SplitKMode split_k_mode{}; - - // - // Methods - // - - /// Default ctor - Arguments() = default; - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size - ): - problem_size(problem_size) { } - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size, - TensorRefA const & ref_A, - TensorRefB const & ref_B, - TensorRefC const & ref_C, - TensorRefC const & ref_D, - typename EpilogueOutputOp::Params const & output_op, - SplitKMode const & split_k_mode = SplitKMode::kSerial - ): - problem_size(problem_size), - ref_A(ref_A), - ref_B(ref_B), - ref_C(ref_C), - ref_D(ref_D), - output_op(output_op), - split_k_mode(split_k_mode) - { - - } - - }; - - /// Parameters structure - struct Params { - ConvProblemSize problem_size{}; - cutlass::gemm::GemmCoord grid_tiled_shape{}; - int swizzle_log_tile{0}; - FastDivmod stride_h_divmod{}; - FastDivmod stride_w_divmod{}; - int gemm_k_iterations{0}; - typename Mma::IteratorA::Params iterator_A{}; - typename Mma::IteratorA::Element const *ptr_A = nullptr; - typename Mma::IteratorB::Params iterator_B{}; - typename Mma::IteratorB::Element const *ptr_B = nullptr; - typename Epilogue::OutputTileIterator::Params iterator_C{}; - typename Epilogue::OutputTileIterator::Element *ptr_C = nullptr; - typename Epilogue::OutputTileIterator::Params iterator_D{}; - typename Epilogue::OutputTileIterator::Element *ptr_D = nullptr; - typename EpilogueOutputOp::Params output_op {}; - int *semaphore = nullptr; - SplitKMode split_k_mode {}; - - // - // Methods - // - Params() = default; - - /// - CUTLASS_HOST_DEVICE - Params( - Arguments const &args, - int *semaphore = nullptr - ): - problem_size(args.problem_size), - stride_h_divmod(args.problem_size.stride_h), - stride_w_divmod(args.problem_size.stride_w), - iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), - ptr_A(args.ref_A.data()), - iterator_B(args.problem_size, args.ref_B.layout()), - ptr_B(args.ref_B.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size, ThreadblockShape::kM), - ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size, ThreadblockShape::kM), - ptr_D(args.ref_D.data()), - output_op(args.output_op), - semaphore(semaphore), - split_k_mode(args.split_k_mode) - { - gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); - - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - kConvolutionalOperator, - args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.problem_size.split_k_slices); - - swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - ImplicitGemmConvolutionStridedDgrad() { } - - /// Executes one ImplicitGEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || - params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { - - return; - } - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Compute starting filter position for strided dgrad - int tile_m_per_filter = strided_dgrad_tile_m_per_filter(params.problem_size, - ThreadblockShape::kM); - int filter_tile_m = (threadblock_tile_idx.m() / tile_m_per_filter); - - - // The subsequent fast_divmod() operations are equivalent to the following logical computation: - // - // int start_r = filter_tile_m / (params.problem_size.stride_w); - // int start_s = filter_tile_m % (params.problem_size.stride_w); - - int start_r, start_s; - params.stride_w_divmod(start_r, start_s, filter_tile_m); - - int filter_r = start_r; - int filter_s = start_s; - - if (params.problem_size.mode == Mode::kConvolution) { - filter_r = (params.problem_size.R - 1 - filter_r); - filter_s = (params.problem_size.S - 1 - filter_s); - } - - // Starting h, w positions for filter position in gemm_k=0 - int start_h, start_w; - strided_dgrad_starting_coords( - params.problem_size, - params.stride_h_divmod, params.stride_w_divmod, - filter_r, filter_s, - start_h, start_w); - - if (start_h >= params.problem_size.H || start_w >= params.problem_size.W) { - return; - } - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - int lane_idx = threadIdx.x % 32; - - // Check if CTA contributes valid MMA (Dy * w) and accumulator will be non-zero after MMA - if (start_r < params.problem_size.R && start_s < params.problem_size.S) { - // Scale gemm_k_iterations for strided dgrad - int gemm_k_iterations = (params.gemm_k_iterations / (params.problem_size.R * params.problem_size.S) - ) * params.problem_size.num_gemm_k_filter_positions(start_r, start_s); - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.iterator_A, - params.problem_size, - params.ptr_A, - thread_idx, - params.stride_h_divmod, params.stride_w_divmod, - start_r, start_s, - MatrixCoord( - threadblock_tile_idx.m() * Mma::Shape::kM, - threadblock_tile_idx.k() * Mma::Shape::kK - ) - ); - - typename Mma::IteratorB iterator_B( - params.iterator_B, - params.problem_size, - params.ptr_B, - thread_idx, - start_r, start_s, - MatrixCoord( - threadblock_tile_idx.k() * Mma::Shape::kK, - threadblock_tile_idx.n() * Mma::Shape::kN - ) - ); - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - } - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // Construct the semaphore. - int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // Compute logical position within grid - threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // If performing a reduction via split-K, fetch the initial synchronization - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); - } - - MatrixCoord threadblock_offset( - threadblock_tile_idx.m() * Mma::Shape::kM, - threadblock_tile_idx.n() * Mma::Shape::kN - ); - - // Tile iterator writing to destination tensor - typename Epilogue::OutputTileIterator iterator_D( - params.iterator_D, - params.ptr_D, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - params.stride_h_divmod, params.stride_w_divmod, - start_r, start_s, - threadblock_offset - ); - - // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - if (output_op.is_source_needed()) - { - // Tile iterator reading from source accumulator tensor - typename Epilogue::OutputTileIterator iterator_C( - params.iterator_C, - params.ptr_C, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - params.stride_h_divmod, params.stride_w_divmod, - start_r, start_s, - threadblock_offset); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_idx.k()) { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_idx.k()); - } - - // Run epilogue with addend source iterator - epilogue(output_op, iterator_D, accumulators, iterator_C); - } - else - { - // Run epilogue without addend source iterator - epilogue(output_op, iterator_D, accumulators); - } - - // - // Release the semaphore - // - - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_idx.k() + 1; - } - - semaphore.release(lock); - } - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h deleted file mode 100644 index 0113473f9b28d7c657c07ff8f85e34fc66ea1ed1..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h +++ /dev/null @@ -1,494 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Convolution kernel with an epilogue that computes the absolute maximum value of the output - and a pre-activation-function auxiliary output. The auxiliary output is also (optionally) - stored to global memory. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/semaphore.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) - typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem -> -struct ImplicitGemmConvolutionWithAbsMax { - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static Operator const kConvolutionalOperator = ConvOperator; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename EpilogueOutputOp::ElementOutput; - - /// Set output tensor C layout - using LayoutC = LayoutA; - - using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; - using ElementCompute = typename EpilogueOutputOp::ElementCompute; - - using WarpMmaOperator = typename Mma::Policy::Operator; - - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename ArchMmaOperator::Operator; - - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename WarpMmaOperator::Shape; - using InstructionShape = typename ArchMmaOperator::Shape; - - static int const kStages = Mma::kStages; - static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; - static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using TensorRefA = typename Mma::IteratorA::TensorRef; - using TensorRefB = typename Mma::IteratorB::TensorRef; - using TensorRefC = cutlass::TensorRef; - using TensorRefAux = cutlass::TensorRef; - - /// Check iterator A and B convolution dimension are the same and - // set device::ImplicitGemmConvolution::kConvDim - static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, - "Convolution on different different dimensions is not supported"); - static int const kConvDim = Mma::IteratorA::kConvDim; - - /// Conv dimension and problem size structure (Conv2d or Conv3d) - using ConvProblemSize = ConvProblemSize_; - - static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; - - /// Wgrad C stride idx for implicit gemm algorithm - // Conv2d row-major matrix C (KxRSC) - // Conv3d row-major matrix C (KxTRSC) - static int const kWgradCStrideIdx = - platform::is_same::value ? 2 : 3; - - /// This chooses the appropriate stride element of the C tensor. - static int const kTensorCStrideIdx = - (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); - - // - // - // - using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< - LayoutC, - typename Epilogue::OutputTileIterator::Layout, - TensorRefC, - ConvOperator, - ConvProblemSize - >; - - /// Argument structure - struct Arguments { - - // - // Data members - // - - ConvProblemSize problem_size; - TensorRefA ref_A; - TensorRefB ref_B; - TensorRefC ref_C; - TensorRefC ref_D; - TensorRefC ref_Aux; - - typename EpilogueOutputOp::Params output_op; - SplitKMode split_k_mode; - - void * ptr_Vector; - - typename LayoutC::Stride::Index ldr; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() { } - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size - ): - problem_size(problem_size) { } - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size, - TensorRefA const & ref_A, - TensorRefB const & ref_B, - TensorRefC const & ref_C, - TensorRefC const & ref_D, - TensorRefAux const & ref_Aux, - typename EpilogueOutputOp::Params const & output_op, - SplitKMode const & split_k_mode = SplitKMode::kSerial, - void * ptr_Vector = nullptr, - typename LayoutC::Stride::Index ldr = 0 - ): - problem_size(problem_size), - ref_A(ref_A), - ref_B(ref_B), - ref_C(ref_C), - ref_D(ref_D), - ref_Aux(ref_Aux), - output_op(output_op), - split_k_mode(split_k_mode), - ptr_Vector(ptr_Vector), - ldr(ldr) - { - - } - - }; - - /// Parameters structure - struct Params { - ConvProblemSize problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - gemm::GemmCoord implicit_gemm_problem_size; - int swizzle_log_tile; - - int gemm_k_iterations; - typename Mma::IteratorA::Params iterator_A; - typename Mma::IteratorA::Element const *ptr_A; - typename Mma::IteratorB::Params iterator_B; - typename Mma::IteratorB::Element const *ptr_B; - typename Epilogue::OutputTileIterator::Params iterator_C; - typename Epilogue::OutputTileIterator::Element *ptr_C; - typename Epilogue::OutputTileIterator::Params iterator_D; - typename Epilogue::OutputTileIterator::Element *ptr_D; - typename Epilogue::AuxOutputTileIterator::Params iterator_Aux; - typename Epilogue::AuxOutputTileIterator::Element *ptr_Aux; - typename EpilogueOutputOp::Params output_op; - int *semaphore; - SplitKMode split_k_mode; - - void * ptr_Vector; - typename LayoutC::Stride::Index ldr; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params(): - swizzle_log_tile(0), - gemm_k_iterations(0), - ptr_Vector(nullptr), - ldr(0) - { } - - /// - CUTLASS_HOST_DEVICE - Params( - Arguments const &args, - int *semaphore = nullptr - ): - problem_size(args.problem_size), - implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), - iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), - ptr_A(args.ref_A.data()), - iterator_B(args.problem_size, args.ref_B.layout()), - ptr_B(args.ref_B.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), - ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), - ptr_D(args.ref_D.data()), - iterator_Aux(ConvOutputIteratorParameter::layout(args.ref_Aux)), - ptr_Aux(args.ref_Aux.data()), - output_op(args.output_op), - semaphore(semaphore), - split_k_mode(args.split_k_mode), - ptr_Vector(args.ptr_Vector), - ldr(args.ldr) - - { - gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); - - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - implicit_gemm_problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.problem_size.split_k_slices); - - swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - ImplicitGemmConvolutionWithAbsMax() { } - - /// Executes one ImplicitGEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || - params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { - - return; - } - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.iterator_A, - params.problem_size, - params.ptr_A, - thread_idx, - MatrixCoord( - threadblock_tile_idx.m() * Mma::Shape::kM, - threadblock_tile_idx.k() * Mma::Shape::kK - ) - ); - - typename Mma::IteratorB iterator_B( - params.iterator_B, - params.problem_size, - params.ptr_B, - thread_idx, - MatrixCoord( - threadblock_tile_idx.k() * Mma::Shape::kK, - threadblock_tile_idx.n() * Mma::Shape::kN - ) - ); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // Construct the semaphore. - int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); - - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // Compute logical position within grid - threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // If performing a reduction via split-K, fetch the initial synchronization - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); - } - - MatrixCoord threadblock_offset( - threadblock_tile_idx.m() * Mma::Shape::kM, - threadblock_tile_idx.n() * Mma::Shape::kN - ); - - // Tile iterator writing to destination tensor - typename Epilogue::OutputTileIterator iterator_D( - params.iterator_D, - params.ptr_D, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset - ); - - // Tile iterator writing to auxiliary tensor. - typename Epilogue::AuxOutputTileIterator iterator_Aux( - params.iterator_Aux, - params.ptr_Aux, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset - ); - - // Tile iterator reading from source accumulator tensor - typename Epilogue::OutputTileIterator iterator_C( - params.iterator_C, - params.ptr_C, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset - ); - - // Define the reduction output pointer and move to the appropriate place - typename Epilogue::ElementVector *ptr_Vector = - static_cast(params.ptr_Vector); - - - // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Move to appropriate location for this output tile - if (ptr_Vector) { - ptr_Vector += threadblock_offset.column() + threadblock_tile_idx.m() * params.ldr; - } - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_idx.k()) { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_idx.k()); - - } - // Each split-k-slice writes to a unique tensor location - else if (params.split_k_mode == SplitKMode::kParallel) { - iterator_D.add_pointer_offset(threadblock_tile_idx.k() * - cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, - // Only the final block uses Vector - ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && - (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) - ? nullptr - : ptr_Vector, - iterator_D, - accumulators, - iterator_C, - iterator_Aux, - ConvOutputIteratorParameter::extent(params.problem_size), - threadblock_offset); - - // - // Release the semaphore - // - - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_idx.k() + 1; - } - - semaphore.release(lock); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h deleted file mode 100644 index 1e810e3d13c8b8eed4894ac9670f4a586dcaef8d..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h +++ /dev/null @@ -1,499 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a pipelined Implicit GEMM kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/semaphore.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad, Deconv) - typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem -> -struct ImplicitGemmConvolutionWithFusedEpilogue { - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static Operator const kConvolutionalOperator = ConvOperator; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename EpilogueOutputOp::ElementOutput; - - /// Set output tensor C layout - using LayoutC = LayoutA; - - using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; - using ElementCompute = typename EpilogueOutputOp::ElementCompute; - - using WarpMmaOperator = typename Mma::Policy::Operator; - - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename ArchMmaOperator::Operator; - - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename WarpMmaOperator::Shape; - using InstructionShape = typename ArchMmaOperator::Shape; - - static int const kStages = Mma::kStages; - static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; - static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using TensorRefA = typename Mma::IteratorA::TensorRef; - using TensorRefB = typename Mma::IteratorB::TensorRef; - using TensorRefC = cutlass::TensorRef; - - /// Check iterator A and B convolution dimension are the same and - // set device::ImplicitGemmConvolution::kConvDim - static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, - "Convolution on different different dimensions is not supported"); - static int const kConvDim = Mma::IteratorA::kConvDim; - - /// Conv dimension and problem size structure (Conv2d or Conv3d) - using ConvProblemSize = ConvProblemSize_; - - static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; - - /// Wgrad C stride idx for implicit gemm algorithm - // Conv2d row-major matrix C (KxRSC) - // Conv3d row-major matrix C (KxTRSC) - static int const kWgradCStrideIdx = - platform::is_same::value ? 2 : 3; - - /// This chooses the appropriate stride element of the C tensor. - static int const kTensorCStrideIdx = - (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); - - // - // - // - using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< - LayoutC, - typename Epilogue::OutputTileIterator::Layout, - TensorRefC, - ConvOperator, - ConvProblemSize - >; - - /// Argument structure - struct Arguments { - - // - // Data members - // - - ConvProblemSize problem_size; - TensorRefA ref_A; - TensorRefB ref_B; - TensorRefC ref_C; - TensorRefC ref_D; - - typename EpilogueOutputOp::Params output_op; - SplitKMode split_k_mode; - - void * ptr_Vector; - void * ptr_Tensor; - - typename LayoutC::Stride::Index ldr; - typename LayoutC::Stride::Index ldt; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() { } - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size - ): - problem_size(problem_size) { } - - CUTLASS_HOST_DEVICE - Arguments( - ConvProblemSize const & problem_size, - TensorRefA const & ref_A, - TensorRefB const & ref_B, - TensorRefC const & ref_C, - TensorRefC const & ref_D, - typename EpilogueOutputOp::Params const & output_op, - SplitKMode const & split_k_mode = SplitKMode::kSerial, - void * ptr_Vector = nullptr, - void * ptr_Tensor = nullptr, - typename LayoutC::Stride::Index ldr = 0, - typename LayoutC::Stride::Index ldt = 0 - ): - problem_size(problem_size), - ref_A(ref_A), - ref_B(ref_B), - ref_C(ref_C), - ref_D(ref_D), - output_op(output_op), - split_k_mode(split_k_mode), - ptr_Vector(ptr_Vector), - ptr_Tensor(ptr_Tensor), - ldr(ldr), - ldt(ldt) - { - - } - - }; - - /// Parameters structure - struct Params { - ConvProblemSize problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - gemm::GemmCoord implicit_gemm_problem_size; - int swizzle_log_tile; - - int gemm_k_iterations; - typename Mma::IteratorA::Params iterator_A; - typename Mma::IteratorA::Element const *ptr_A; - typename Mma::IteratorB::Params iterator_B; - typename Mma::IteratorB::Element const *ptr_B; - typename Epilogue::OutputTileIterator::Params iterator_C; - typename Epilogue::OutputTileIterator::Element *ptr_C; - typename Epilogue::OutputTileIterator::Params iterator_D; - typename Epilogue::OutputTileIterator::Element *ptr_D; - typename EpilogueOutputOp::Params output_op; - int *semaphore; - SplitKMode split_k_mode; - - typename Epilogue::TensorTileIterator::Params params_Tensor; - void * ptr_Vector; - typename LayoutC::Stride::Index ldr; - void * ptr_Tensor; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params(): - swizzle_log_tile(0), - gemm_k_iterations(0), - ptr_Vector(nullptr), - ldr(0), - ptr_Tensor(nullptr) - { } - - /// - CUTLASS_HOST_DEVICE - Params( - Arguments const &args, - int *semaphore = nullptr - ): - problem_size(args.problem_size), - implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), - iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), - ptr_A(args.ref_A.data()), - iterator_B(args.problem_size, args.ref_B.layout()), - ptr_B(args.ref_B.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), - ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), - ptr_D(args.ref_D.data()), - output_op(args.output_op), - semaphore(semaphore), - split_k_mode(args.split_k_mode), - params_Tensor(args.ldt), - ptr_Vector(args.ptr_Vector), - ldr(args.ldr), - ptr_Tensor(args.ptr_Tensor) - - { - gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); - - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - implicit_gemm_problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.problem_size.split_k_slices); - - swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - ImplicitGemmConvolutionWithFusedEpilogue() { } - - /// Executes one ImplicitGEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || - params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { - - return; - } - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.iterator_A, - params.problem_size, - params.ptr_A, - thread_idx, - MatrixCoord( - threadblock_tile_idx.m() * Mma::Shape::kM, - threadblock_tile_idx.k() * Mma::Shape::kK - ) - ); - - typename Mma::IteratorB iterator_B( - params.iterator_B, - params.problem_size, - params.ptr_B, - thread_idx, - MatrixCoord( - threadblock_tile_idx.k() * Mma::Shape::kK, - threadblock_tile_idx.n() * Mma::Shape::kN - ) - ); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // Construct the semaphore. - int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); - - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // Compute logical position within grid - threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // If performing a reduction via split-K, fetch the initial synchronization - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); - } - - MatrixCoord threadblock_offset( - threadblock_tile_idx.m() * Mma::Shape::kM, - threadblock_tile_idx.n() * Mma::Shape::kN - ); - - // Tile iterator writing to destination tensor - typename Epilogue::OutputTileIterator iterator_D( - params.iterator_D, - params.ptr_D, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset - ); - - // Tile iterator reading from source accumulator tensor - typename Epilogue::OutputTileIterator iterator_C( - params.iterator_C, - params.ptr_C, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset - ); - - typename Epilogue::ElementTensor *ptr_Tensor = - static_cast(params.ptr_Tensor); - - // Define the reduction output pointer and move to the appropriate place - typename Epilogue::ElementVector *ptr_Vector = - static_cast(params.ptr_Vector); - - // Additional tensor to load from - typename Epilogue::TensorTileIterator tensor_iterator( - params.params_Tensor, - // Only the final block outputs Tensor - ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && - (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) - ? nullptr - : ptr_Tensor, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - threadblock_offset); - - // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Move to appropriate location for this output tile - if (ptr_Vector) { - ptr_Vector += threadblock_offset.column() + threadblock_tile_idx.m() * params.ldr; - } - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_idx.k()) { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_idx.k()); - - } - // Each split-k-slice writes to a unique tensor location - else if (params.split_k_mode == SplitKMode::kParallel) { - iterator_D.add_pointer_offset(threadblock_tile_idx.k() * - cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, - // Only the final block uses Vector - ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && - (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) - ? nullptr - : ptr_Vector, - iterator_D, - accumulators, - iterator_C, - tensor_iterator, - ConvOutputIteratorParameter::extent(params.problem_size), - threadblock_offset); - - // - // Release the semaphore - // - - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_idx.k() + 1; - } - - semaphore.release(lock); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp deleted file mode 100644 index 327fc27db4eba8093ce58845e465071da724c2e8..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp +++ /dev/null @@ -1,874 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/kernel_hardware_info.hpp" - -#include "cute/tensor.hpp" -#include "cute/arch/tmem_allocator_sm100.hpp" -#include "cute/arch/cluster_sm90.hpp" - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/grid_dependency_control.h" -#include "cutlass/conv/detail.hpp" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/pipeline/sm100_pipeline.hpp" -#include "cutlass/detail/sm100_tmem_helper.hpp" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileSchedulerTag_ -> -class ConvUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileSchedulerTag_, - cute::enable_if_t>> -{ -public: - // - // Type Aliases - // - - // Mainloop derived types - using ProblemShape = ProblemShape_; - using CollectiveMainloop = CollectiveMainloop_; - - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; - using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; - static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; - static constexpr bool is_grouped_wgrad = CollectiveMainloop::is_grouped_wgrad; - static constexpr bool IsComplex = false; - static_assert(ArchTag::kMinComputeCapability >= 100); - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; - // TileID scheduler - // CLC pipeline depth determines how many waves (stages-1) the scheduler can race ahead - static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; - static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; - - using TileSchedulerTag = TileSchedulerTag_; - using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelector< - TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - - static constexpr bool IsDynamicCluster = not cute::is_static_v; - - // Warp specialization thread count per threadblock - static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; - static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; - - static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + - NumMainloopLoadThreads + NumMMAThreads + - NumEpilogueLoadThreads + NumEpilogueThreads; - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - static constexpr uint32_t NumFixupBarriers = 1; - - // Pipelines and pipeline states - static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); - - // Pipeline and pipeline state types - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - using MainloopPipelineState = typename CollectiveMainloop::MainloopPipelineState; - - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; - - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; - - using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; - - using AccumulatorPipeline = cutlass::PipelineUmmaAsync; - using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; - - using CLCPipeline = cutlass::PipelineCLCFetchAsync; - using CLCPipelineState = cutlass::PipelineDetail::PipelineCLCFetchAsyncPipelineState; - using CLCPipelineSharedStorage = cutlass::PipelineDetail::PipelineCLCFetchAsyncSharedStorage; - - using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, - cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; - - // Kernel level shared memory storage - struct SharedStorage { - struct PipelineStorage : cute::aligned_struct<16, _1> { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; - using CLCPipelineStorage = CLCPipelineSharedStorage; - using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - alignas(16) LoadOrderBarrierStorage load_order; - alignas(16) CLCPipelineStorage clc; - alignas(16) AccumulatorPipelineStorage accumulator; - alignas(16) arch::ClusterBarrier tmem_dealloc; - } pipelines; - - alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; - uint32_t tmem_base_ptr; - - struct TensorStorage : cute::aligned_struct<128, _1> { - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - - EpilogueTensorStorage epilogue; - MainloopTensorStorage mainloop; - } tensors; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); - - // Host facing host arguments - struct Arguments { - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel device entry point API - struct Params { - using ProblemShapeMNKL = decltype(CollectiveMainloop::get_problem_shape_MNKL(ProblemShape{})); - ProblemShapeMNKL problem_shape; - MainloopParams mainloop; - EpilogueParams epilogue; - TileSchedulerParams scheduler; - KernelHardwareInfo hw_info{}; - }; - - enum class WarpCategory : int32_t { - MMA = 0, - Sched = 1, - MainloopLoad = 2, - EpilogueLoad = 3, - Epilogue = 4 - }; - - struct IsParticipant { - uint32_t mma = false; - uint32_t sched = false; - uint32_t main_load = false; - uint32_t epi_load = false; - uint32_t epilogue = false; - }; - - // - // Methods - // - // Map user facing arguments to device facing params - CUTLASS_HOST - static Params - to_underlying_arguments(Arguments const& args, void* workspace) { - static constexpr uint32_t NumEpilogueSubTiles = 1; - - auto problem_shape_mnkl = CollectiveMainloop::get_problem_shape_MNKL(args.problem_shape); - - auto mainloop_params = CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace, args.hw_info); - - // Calculate workspace pointers - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - // Epilogue - void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - // Tile scheduler - void* scheduler_workspace = workspace_ptr + workspace_offset; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, problem_shape_mnkl, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - return { - problem_shape_mnkl, - mainloop_params, - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), - TileScheduler::to_underlying_arguments( - args.problem_shape, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, - args.hw_info, args.scheduler, scheduler_workspace), - args.hw_info - }; - } - - CUTLASS_HOST - static bool - can_implement(Arguments const& args) { - bool implementable = true; - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - - if constexpr (IsDynamicCluster) { - static constexpr int MaxClusterSize = 16; - implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; - implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; - implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); - } - - auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape); - auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape_fallback); - - // implicit gemm B tile can be small for conv, ensure multicast smem offsets are 128B aligned - int multicast_b_bits = (size<1>(TileShape{}) * size<2>(TileShape{}) / size<0>(cluster_shape)) * sizeof_bits_v; - int multicast_b_fallback_bits = (size<1>(TileShape{}) * size<2>(TileShape{}) / size<0>(cluster_shape_fallback)) * sizeof_bits_v; - implementable &= multicast_b_bits % (128*8) == 0 && multicast_b_fallback_bits % (128*8) == 0; - if (not implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: multicast size too large for B tile\n"); - return false; - } - - if constexpr (is_grouped_wgrad) { - implementable &= size<0>(cluster_shape) == 1 && size<0>(cluster_shape_fallback) == 1; - - if (!implementable) { - return false; - } - } - - return implementable; - } - - CUTLASS_HOST - static size_t - get_workspace_size(Arguments const& args) { - static constexpr uint32_t NumEpilogueSubTiles = 1; - size_t workspace_size = 0; - auto linear_problem_shape_MNKL = cutlass::conv::detail::get_linearized_problem_shape_MNKL(args.problem_shape); - - // Epilogue - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - // Tile scheduler - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, linear_problem_shape_MNKL, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - CUTLASS_HOST - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - static constexpr uint32_t NumEpilogueSubTiles = 1; - auto linear_problem_shape_MNKL = cutlass::conv::detail::get_linearized_problem_shape_MNKL(args.problem_shape); - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - // Epilogue - status = CollectiveEpilogue::initialize_workspace( - args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - // Tile scheduler - status = TileScheduler::template initialize_workspace - ( - args.scheduler, workspace_ptr + workspace_offset, stream, linear_problem_shape_MNKL, - args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); - - workspace_offset += TileScheduler::template get_workspace_size - ( - args.scheduler, linear_problem_shape_MNKL, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, - CollectiveEpilogue::NumAccumulatorMtxs); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - // Computes the kernel launch grid shape based on runtime parameters - CUTLASS_HOST - static dim3 - get_grid_shape(Params const& params) { - auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); - - return TileScheduler::get_grid_shape( - params.scheduler, - params.problem_shape, - TileShape{}, - AtomThrShapeMNK{}, - cluster_shape - ,params.hw_info - ); - } - - CUTLASS_HOST - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - - using namespace cute; - using X = Underscore; - - // Separate out problem shape for convenience - auto problem_shape_MNKL = append<4>(params.problem_shape, _1{}); - auto [M, N, K, L] = problem_shape_MNKL; - - // Account for more than one epilogue warp - int warp_idx = canonical_warp_idx_sync(); - WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) - : WarpCategory::Epilogue; - - uint32_t lane_predicate = cute::elect_one_sync(); - auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}); - int cluster_size = size(cluster_shape); - uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); - bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; - int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); - bool is_mma_leader_cta = cta_coord_v == 0; - constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; - [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_category == WarpCategory::Sched) && lane_predicate) { - collective_mainloop.prefetch_tma_descriptors(); - } - if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { - collective_epilogue.prefetch_tma_descriptors(params.epilogue); - } - - // Do we load source tensor C or other aux inputs - bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); - IsParticipant is_participant = { - (warp_category == WarpCategory::MMA), // mma - (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched - (warp_category == WarpCategory::MainloopLoad), // main_load - (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load - (warp_category == WarpCategory::Epilogue) // epilogue - }; - - // Mainloop Load pipeline - typename MainloopPipeline::Params mainloop_pipeline_params; - if (WarpCategory::MainloopLoad == warp_category) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (WarpCategory::MMA == warp_category) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_load; - mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; - mainloop_pipeline_params.initializing_warp = 0; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, - mainloop_pipeline_params, - cluster_shape, - cute::true_type{}, // Perform barrier init - cute::false_type{}); // Delay mask calculation - - // Epilogue Load pipeline - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (WarpCategory::EpilogueLoad == warp_category) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (WarpCategory::Epilogue == warp_category) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; - epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; - epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; - epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; - epi_load_pipeline_params.initializing_warp = 1; - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - // Load order barrier - typename LoadOrderBarrier::Params load_order_barrier_params; - load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; - load_order_barrier_params.group_size = NumMainloopLoadThreads; - load_order_barrier_params.initializing_warp = 3; - LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); - - // CLC pipeline - typename CLCPipeline::Params clc_pipeline_params; - if (WarpCategory::Sched == warp_category) { - clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; - } - else { - clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; - } - clc_pipeline_params.producer_blockid = 0; - clc_pipeline_params.producer_arv_count = 1; - clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * - (NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads); - if (is_epi_load_needed) { - clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; - } - clc_pipeline_params.transaction_bytes = CLCResponseSize; - clc_pipeline_params.initializing_warp = 4; - CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); - - // Mainloop-Epilogue pipeline - typename AccumulatorPipeline::Params accumulator_pipeline_params; - if (WarpCategory::MMA == warp_category) { - accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; - } - if (WarpCategory::Epilogue == warp_category) { - accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; - } - // Only one producer thread arrives on this barrier. - accumulator_pipeline_params.producer_arv_count = 1; - accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; - accumulator_pipeline_params.initializing_warp = 5; - AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, - accumulator_pipeline_params, - cluster_shape, - cute::true_type{}, // Perform barrier init - cute::false_type{}); // Delay mask calculation - - // Tmem allocator - TmemAllocator tmem_allocator{}; - - // Sync allocation status between MMA and epilogue warps within CTA - arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); - // Sync deallocation status between MMA warps of peer CTAs - arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; - [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; - if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { - tmem_deallocation_result_barrier.init(NumMMAThreads); - } - - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer threadblocks in the cluster - pipeline_init_arrive_relaxed(cluster_size); - - auto load_inputs = collective_mainloop.load_init( - problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); - - uint32_t tmem_stage_ptrs[AccumulatorPipelineStageCount]; - MainloopPipelineState mainloop_pipe_consumer_state; - MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - - EpiLoadPipelineState epi_load_pipe_consumer_state; - EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - - // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - CLCPipelineState clc_pipe_consumer_state; - CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); - - AccumulatorPipelineState accumulator_pipe_consumer_state; - AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); - - dim3 block_id_in_cluster = cute::block_id_in_cluster(); - - // Calculate mask after cluster barrier arrival - mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); - accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); - - // TileID scheduler - TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, problem_shape_MNKL, TileShape{}, block_id_in_cluster); - typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); - auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); - auto acc_shape = collective_mainloop.partition_accumulator_shape(); - auto accumulators = TiledMma::make_fragment_C(acc_shape); - - int TmemColumnsPerAccumulatorTile = cutlass::detail::find_tmem_tensor_col_offset(accumulators); - pipeline_init_wait(cluster_size); - - if (is_participant.main_load) { - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); - - bool do_load_order_arrive = is_epi_load_needed; - Tensor gA_mk = get<0>(load_inputs); - - do { - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, TileShape{}, shape<3>(gA_mk)); - auto k_tile_count = scheduler.get_work_k_tile_count(work_tile_info, problem_shape_MNKL, TileShape{}); - auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); - - auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load( - params.mainloop, - mainloop_pipeline, - mainloop_pipe_producer_state, - load_inputs, - cta_coord_mnkl, - k_tile_iter, k_tile_prologue - ); - mainloop_pipe_producer_state = mainloop_producer_state_next; - - if (do_load_order_arrive) { - load_order_barrier.arrive(); - do_load_order_arrive = false; - } - - auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load( - params.mainloop, - mainloop_pipeline, - mainloop_pipe_producer_state, - load_inputs, - cta_coord_mnkl, - k_tile_iter_next, k_tile_count - k_tile_prologue - ); - mainloop_pipe_producer_state = mainloop_producer_state_next_; - - // Sync warp to prevent non-participating threads entering next wave early - __syncwarp(); - - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - clc_pipeline, - clc_pipe_consumer_state - ); - work_tile_info = next_work_tile_info; - cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); - if (increment_pipe) { - ++clc_pipe_consumer_state; - } - } while (work_tile_info.is_valid()); - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - - } - - else if (is_participant.sched) { - // Whether a new CLC query must be performed. - // See comment below where this variable is updated for a description of - // why this variable is needed. - bool requires_clc_query = true; - - do { - if (requires_clc_query) { - // Query next clcID and update producer state - clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); - } - - // Fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - clc_pipeline, - clc_pipe_consumer_state - ); - - // Only perform a new CLC query if we consumed a new CLC query result in - // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does - // not consume a new CLC query response is when processing stream-K units. - // The current stream-K scheduler uses single WorkTileInfo to track multiple - // (potentially-partial) tiles to be computed via stream-K. In this case, - // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, - // rather than consuming a CLC query response. - requires_clc_query = increment_pipe; - if (increment_pipe) { - ++clc_pipe_consumer_state; - } - - work_tile_info = next_work_tile_info; - } while (work_tile_info.is_valid()); - clc_pipeline.producer_tail(clc_pipe_producer_state); - } - - else if (is_participant.mma) { - // Tmem allocation sequence - tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); - __syncwarp(); - tmem_allocation_result_barrier.arrive(); - uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; - - CUTLASS_PRAGMA_UNROLL - for (int acc_stage = 0; acc_stage < AccumulatorPipelineStageCount; acc_stage++) { - tmem_stage_ptrs[acc_stage] = tmem_base_ptr + (TmemColumnsPerAccumulatorTile * acc_stage) & cutlass::detail::TmemColMask; - } - auto mma_inputs = collective_mainloop.mma_init(shared_storage.tensors.mainloop); - - do { - auto k_tile_count = scheduler.get_work_k_tile_count(work_tile_info, problem_shape_MNKL, TileShape{}); - - // Fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - clc_pipeline, - clc_pipe_consumer_state - ); - - if (increment_pipe) { - ++clc_pipe_consumer_state; - } - - // Wait for tmem accumulator buffer to become empty with a flipped phase - if (is_mma_leader_cta) { - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); - } - - // Accumulator stage slice - int acc_stage = accumulator_pipe_producer_state.index(); - accumulators.data() = tmem_stage_ptrs[acc_stage]; - - if (is_mma_leader_cta) { - mainloop_pipe_consumer_state = collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - mma_inputs, - k_tile_count - ); - accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); - } - ++accumulator_pipe_producer_state; - work_tile_info = next_work_tile_info; - } while (work_tile_info.is_valid()); - - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. - cutlass::arch::launch_dependent_grids(); - - // Release the right to allocate before deallocations so that the next CTA can rasterize - tmem_allocator.release_allocation_lock(); - - // Leader MMA waits for leader + peer epilogues to release accumulator stage - if (is_mma_leader_cta) { - accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); - } - // Signal to peer MMA that entire tmem allocation can be deallocated - if constexpr (has_mma_peer_cta) { - // Leader does wait + arrive, follower does arrive + wait - tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); - tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); - tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); - } - - // Free entire tmem allocation - tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); - } - - else if (is_participant.epi_load) { - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); - - bool do_load_order_wait = true; - bool do_tail_load = false; - - do { - bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); - - // Get current work tile and fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - clc_pipeline, - clc_pipe_consumer_state - ); - work_tile_info = next_work_tile_info; - - if (increment_pipe) { - ++clc_pipe_consumer_state; - } - - if (compute_epilogue) { - if (do_load_order_wait) { - load_order_barrier.wait(); - do_load_order_wait = false; - } - - epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - CtaShape_MNK{}, - cta_coord_mnkl, - TileShape{}, - TiledMma{}, - shared_storage.tensors.epilogue - ); - - do_tail_load = true; - } - - // Calculate the cta coordinates of the next work tile - cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); - } while (work_tile_info.is_valid()); - - // Only perform a tail load if one of the work units processed performed - // an epilogue load. An example of a case in which a tail load should not be - // performed is in split-K if a cluster is only assigned non-final splits (for which - // the cluster does not compute the epilogue). - if (do_tail_load) { - collective_epilogue.load_tail( - epi_load_pipeline, epi_load_pipe_producer_state, - epi_store_pipeline, epi_store_pipe_producer_state); - } - } - - else if (is_participant.epilogue) { - // Wait for tmem allocate here - tmem_allocation_result_barrier.arrive_and_wait(); - uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; - CUTLASS_PRAGMA_UNROLL - for (int acc_stage = 0; acc_stage < AccumulatorPipelineStageCount; acc_stage++) { - tmem_stage_ptrs[acc_stage] = tmem_base_ptr + (TmemColumnsPerAccumulatorTile * acc_stage) & cutlass::detail::TmemColMask; - } - - bool do_tail_store = false; - do { - // Fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - clc_pipeline, - clc_pipe_consumer_state - ); - - if (increment_pipe) { - ++clc_pipe_consumer_state; - } - - // Accumulator stage slice after making sure allocation has been performed - int acc_stage = accumulator_pipe_consumer_state.index(); - accumulators.data() = tmem_stage_ptrs[acc_stage]; - - accumulator_pipe_consumer_state = scheduler.template fixup( - TiledMma{}, - work_tile_info, - accumulators, - accumulator_pipeline, - accumulator_pipe_consumer_state, - typename CollectiveEpilogue::CopyOpT2R{} - ); - - // - // Epilogue and write to gD - // - if (scheduler.compute_epilogue(work_tile_info)) { - auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - accumulator_pipeline, - accumulator_pipe_consumer_state, - problem_shape_MNKL, - CtaShape_MNK{}, - cta_coord_mnkl, - TileShape{}, - TiledMma{}, - accumulators, - shared_storage.tensors.epilogue - ); - epi_load_pipe_consumer_state = load_state_next; - epi_store_pipe_producer_state = store_state_next; - accumulator_pipe_consumer_state = acc_state_next; - do_tail_store = true; - } - work_tile_info = next_work_tile_info; - cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); - } while (work_tile_info.is_valid()); - - // Only perform a tail store if one of the work units processed performed - // an epilogue. An example of a case in which a tail load should not be - // performed is in split-K if a cluster is only assigned non-final splits (for which - // the cluster does not compute the epilogue). - if (do_tail_store) { - collective_epilogue.store_tail( - epi_load_pipeline, epi_load_pipe_consumer_state, - epi_store_pipeline, epi_store_pipe_producer_state, - CtaShape_MNK{}); - } - } - - else { - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp deleted file mode 100644 index 2c02a4531edd4078da6c92205f36b62b237c20bc..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp +++ /dev/null @@ -1,76 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/kernel_hardware_info.hpp" - -#include "cute/tensor.hpp" -#include "cute/arch/cluster_sm90.hpp" - -#include "cutlass/conv/detail.hpp" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/dispatch_policy.hpp" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/pipeline/sm90_pipeline.hpp" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> -class ConvUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t> -> : public cutlass::gemm::kernel::GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_ -> -{}; -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv::kernel - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/thread/depthwise_mma.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/thread/depthwise_mma.h deleted file mode 100644 index 41eaba2f64b1c14fd85de632b1bfe8c9a3efbc1e..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/thread/depthwise_mma.h +++ /dev/null @@ -1,325 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing architecture support for depthwise convolution -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/arch/mma.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/thread/mma.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// MMA operation -template < - /// Size of the matrix product (concept: GemmShape) - typename Shape_, - /// Number of threads participating - int kThreads_, - /// Data type of A elements - typename ElementA, - /// Data type of B elements - typename ElementB, - /// Element type of C matrix - typename ElementC, - /// Inner product operator - typename Operator -> -struct ElementwiseInnerProduct; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// General implementation -template < - /// Size of the matrix product (concept: GemmShape) - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Data type of B elements - typename ElementB_, - /// Element type of C matrix - typename ElementC_> -struct ElementwiseInnerProduct { - using Shape = Shape_; - using Operator = arch::OpMultiplyAdd; - using ElementC = ElementC_; - - CUTLASS_HOST_DEVICE - void operator()(Array &d, - Array const &a, - Array const &b, - Array const &c) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Shape::kN; ++i) { - d[i] = a[i] * b[i] + c[i]; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Specialization of half_t -template <> -struct ElementwiseInnerProduct< - gemm::GemmShape<2, 2, 1>, - 1, - half_t, - half_t, - half_t, - arch::OpMultiplyAdd> { - - using Shape = gemm::GemmShape<2, 2, 1>; - using Operator = arch::OpMultiplyAdd; - using ElementC = half_t; - - CUTLASS_HOST_DEVICE - void operator()( - Array &d, - Array const &a, - Array const &b, - Array const &c - ) { - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) - - __half2 const & A = reinterpret_cast<__half2 const &>(a); - __half2 const & B = reinterpret_cast<__half2 const &>(b); - __half2 const & C = reinterpret_cast<__half2 const &>(c); - - __half2 tmp_D = __hfma2(A, B, C); - - d = reinterpret_cast const &>(tmp_D); - -#else - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - d[i] = a[i] * b[i] + c[i]; - } -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape, - /// Data type of A elements - typename ElementA, - /// Data type of B elements - typename ElementB, - /// Element type of C matrix - typename ElementC, - /// Concept: arch::OpMultiplyAdd or arch::Mma<> - typename Operator = arch::OpMultiplyAdd, - /// Used for partial specialization - typename Enable = bool -> -struct DepthwiseDirectConvElementwiseInnerProduct; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Gemplate that handles all packed matrix layouts -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Data type of B elements - typename ElementB_, - /// Element type of C matrix - typename ElementC_, - /// Operator used to compute GEMM - typename Operator_ -> -struct DepthwiseDirectConvElementwiseInnerProductGeneric { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// Data type of operand A - using ElementA = ElementA_; - - /// Data type of operand B - using ElementB = ElementB_; - - /// Element type of operand C - using ElementC = ElementC_; - - /// Underlying mathematical operator - using Operator = Operator_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Instruction - using MmaOp = cutlass::conv::thread::ElementwiseInnerProduct< - gemm::GemmShape, - 1, - ElementA, - ElementB, - ElementC, - Operator>; - - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - Array *ptr_D = reinterpret_cast *>(&D); - Array const *ptr_A = - reinterpret_cast const *>(&A); - Array const *ptr_B = - reinterpret_cast const *>(&B); - - MmaOp mma_op; - - // Copy accumulators - D = C; - - // Compute matrix product - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Shape::kN / MmaOp::Shape::kN; ++n) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Shape::kM; ++m) { - - Array tmpD = ptr_D[m * Shape::kN / MmaOp::Shape::kN + n]; - Array tmpA = ptr_A[m * Shape::kN / MmaOp::Shape::kN + n]; - Array tmpB = ptr_B[n]; - - mma_op(tmpD, tmpA, tmpB, tmpD); - - ptr_D[m * Shape::kN / MmaOp::Shape::kN + n] = tmpD; - - } - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Data type of B elements - typename ElementB_, - /// Element type of C matrix - typename ElementC_ -> -struct DepthwiseDirectConvElementwiseInnerProduct< - Shape_, - ElementA_, - ElementB_, - ElementC_, - arch::OpMultiplyAdd - > { - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// Data type of operand A - using ElementA = ElementA_; - - /// Data type of operand B - using ElementB = ElementB_; - - /// Element type of operand C - using ElementC = ElementC_; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - /// A operand storage - using FragmentA = - Array; // output_tile_size per thread * groups_per_thread - - /// B operand storage - using FragmentB = Array; // 1 * groups_per_thread - - /// C operand storage - using FragmentC = - Array; // output_tile_size per thread * groups_per_thread - - static bool const use_optimized = 0; - - using ArchMmaOperator = DepthwiseDirectConvElementwiseInnerProductGeneric; - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - ArchMmaOperator mma; - - mma(D, A, B, C); - - } -}; - -} // namespace thread -} // namespace conv -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h deleted file mode 100644 index 2da2b73b3afe3d5f5800c84d2edb2b220003ba83..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h +++ /dev/null @@ -1,485 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dDgradFilterTileAccessIteratorAnalytic; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Conv2dDgradFilterTileAccessIteratorAnalytic strided dgrad needs special handling to skip MMAs -// on non-contributing w positions -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ -> -class Conv2dDgradFilterTileAccessIteratorAnalytic < - Shape_, - Element_, - ThreadMap_, - conv::StrideSupport::kStrided, - AccessType_ -> { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static_assert(sizeof_bits::value >= 8, - "DGRAD requires elements of size 8b or larger."); - - // - // Parameters structure - // - - using Params = Conv2dAnalyticParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension - int filter_r_; - int filter_s_; - int start_r_; - int start_s_; - int offset_k_[ThreadMap::Iterations::kStrided]; - int offset_c_[ThreadMap::Iterations::kContiguous]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dDgradFilterTileAccessIteratorAnalytic( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - int start_r, int start_s, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - filter_r_(start_r), - filter_s_(start_s), - start_r_(start_r), - start_s_(start_s) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() - + c * ThreadMap::Delta::kContiguous; - } - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_k_[s] = - threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - } - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // Moves filter_s - filter_s_ += problem_size_.stride_w; - if (filter_s_ < problem_size_.S) { - return; - } - // Restore filter_s - filter_s_ = start_s_; - - // Move filter_r - filter_r_ += problem_size_.stride_h; - if (filter_r_ < problem_size_.R) { - return; - } - // Restore filter_r - filter_r_ = start_r_; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; - } - } - - /// Returns the coordinate in the filter tensor w that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int k = offset_k_[iteration_strided_]; - int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; - - return TensorCoord(k, filter_r_, filter_s_, c); - } - - /// Returns true if the current coordinate is within the filter tensor w - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - return coord.n() < problem_size_.K && coord.c() < problem_size_.C; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Conv2dDgradFilterTileAccessIteratorAnalytic unity strided dgrad is more performant for dgrad -// on problem sizes with stride = {1x1} -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ -> -class Conv2dDgradFilterTileAccessIteratorAnalytic < - Shape_, - Element_, - ThreadMap_, - conv::StrideSupport::kUnity, - AccessType_ ->{ -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static_assert(sizeof_bits::value >= 8, - "DGRAD requires elements of size 8b or larger."); - - // - // Parameters structure - // - - using Params = Conv2dAnalyticParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension - int filter_r_; - int filter_s_; - int offset_k_[ThreadMap::Iterations::kStrided]; - int offset_c_[ThreadMap::Iterations::kContiguous]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dDgradFilterTileAccessIteratorAnalytic( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - filter_r_(0), - filter_s_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() - + c * ThreadMap::Delta::kContiguous; - } - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_k_[s] = - threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - } - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next tile - ++filter_s_; - if (filter_s_ < problem_size_.S) { - return; - } - filter_s_ = 0; - ++filter_r_; - if (filter_r_ < problem_size_.R) { - return; - } - filter_r_ = 0; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; - } - } - - /// Returns the coordinate in the filter tensor w that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int k = offset_k_[iteration_strided_]; - int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; - - return TensorCoord(k, filter_r_, filter_s_, c); - } - - /// Returns true if the current coordinate is within the filter tensor w - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - return coord.n() < problem_size_.K && coord.c() < problem_size_.C; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h deleted file mode 100644 index 8a5e60b9d134d8ec5d28da7e486bc5c7f6629a39..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h +++ /dev/null @@ -1,619 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" - -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dDgradFilterTileAccessIteratorOptimized; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad -// on problem sizes with stride = {1x1} -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ -> -class Conv2dDgradFilterTileAccessIteratorOptimized < - Shape_, - Element_, - ThreadMap_, - conv::StrideSupport::kStrided, - AccessType_ - > { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Parameters structure - // - - struct Params : Conv2dStridedDgradFilterIteratorOptimizedParams { - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params(Conv2dStridedDgradFilterIteratorOptimizedParams const &base): - Conv2dStridedDgradFilterIteratorOptimizedParams(base) { } - - CUTLASS_HOST_DEVICE - Params( - Conv2dProblemSize const &problem_size, - Layout const &layout - ): - Conv2dStridedDgradFilterIteratorOptimizedParams( - problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} - ) { } - - }; - -private: - - Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - uint32_t predicates_[kAccessesPerVector]; - int filter_k_; - int filter_r_; - int filter_s_; - - int start_r_; - int start_s_; - - int64_t reset_bytes_s_; - int64_t reset_bytes_r_; - - // - // Assertions - // - - // We map predicates into bits packed in this uint32_t container - static_assert(ThreadMap::Iterations::kStrided * - ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, - "Currently, the number of loads per iteration is limited by the size of the predicates container."); - -public: - - CUTLASS_HOST_DEVICE - Conv2dDgradFilterTileAccessIteratorOptimized( - Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - int start_r, int start_s, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - predicates_{0}, - filter_r_(start_r), - filter_s_(start_s), - start_r_(start_r), - start_s_(start_s) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.row() + thread_coord.strided(); - Index column = threadblock_offset.column() + thread_coord.contiguous(); - - reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0]; - reset_bytes_r_ = reset_bytes_s_ + - (problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1]; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; - int filter_c = column + c * ThreadMap::Delta::kContiguous; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0); - - int pred_idx = c + s * ThreadMap::Iterations::kContiguous; - - predicates_[v] |= (pred << pred_idx); - } - } - } - - TensorCoord coord{filter_k_, filter_r_, filter_s_, column}; - - pointer_ += params_.layout(coord) * sizeof_bits::value / 8; - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_DEVICE - void advance() { - - int next_idx = 0; - LongIndex reset_bytes = params_.reset_bytes; - - // Move filter_s by stride_w - filter_s_ += problem_size_.stride_w; - if (filter_s_ >= problem_size_.S) { - - // Restore filter_s - filter_s_ = start_s_; - - // Move filter_r by stride_h - filter_r_ += problem_size_.stride_h; -#if 0 - bool check = (filter_r_ < problem_size_.R); - - filter_r_ = check ? filter_r_ : start_r_; - next_idx = check ? 1 : 2; - reset_bytes += (check ? reset_bytes_s_ : reset_bytes_r_); -#else - asm volatile( - "{\n\t" - " .reg .pred %%p;\n\t" - " .reg .s64 t1;\n\t" - " setp.lt.s32 %%p, %3, %4;\n\t" - " selp.s32 %0, %3, %5, %%p;\n\t" - " selp.s32 %1, 1, 2, %%p;\n\t" - " selp.s64 t1, %6, %7, %%p;\n\t" - " add.s64 %2, %8, t1;\n\t" - "}\n" - : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) - : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), - "l"(reset_bytes_s_), "l"(reset_bytes_r_), "l"(reset_bytes)); -#endif - } - - // offset pointers by offset_bytes - pointer_ += (params_.inc_next[next_idx] - reset_bytes); - - if (next_idx == 2) { - filter_k_ += params_.filter_k_delta; - } - - // Clear predicates if needed - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { - uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - predicates_[v] = (predicates_[v] & (~kClearMask)); - } - } - } - } - - /// Returns true if the current coordinate is within the filter tensor W - CUTLASS_HOST_DEVICE - bool valid() { - LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; - return (predicates_[iteration_vector_] & (1u << pred_idx)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - return reinterpret_cast(pointer_ + - iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + iteration_vector_; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dDgradFilterTileAccessIteratorOptimized &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - - // Move to the next K coordinate within the tile - pointer_ += params_.inc_next_strided; - - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad -// on problem sizes with stride = {1x1} -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ -> -class Conv2dDgradFilterTileAccessIteratorOptimized < - Shape_, - Element_, - ThreadMap_, - conv::StrideSupport::kUnity, - AccessType_ - > { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Parameters structure - // - - struct Params : Conv2dDgradFilterIteratorOptimizedParams { - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params(Conv2dDgradFilterIteratorOptimizedParams const &base): - Conv2dDgradFilterIteratorOptimizedParams(base) { } - - CUTLASS_HOST_DEVICE - Params( - Conv2dProblemSize const &problem_size, - Layout const &layout - ): - Conv2dDgradFilterIteratorOptimizedParams( - problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} - ) { } - - }; - -private: - - Conv2dDgradFilterIteratorOptimizedParams const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - uint32_t predicates_[kAccessesPerVector]; - int filter_rs_; - int filter_k_; - - // - // Assertions - // - - // We map predicates into bits packed in this uint32_t container - static_assert(ThreadMap::Iterations::kStrided * - ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, - "Currently, the number of loads per iteration is limited by the size of the predicates container."); - -public: - - CUTLASS_HOST_DEVICE - Conv2dDgradFilterTileAccessIteratorOptimized( - Conv2dDgradFilterIteratorOptimizedParams const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - predicates_{0}, - filter_rs_(0), - filter_k_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.row() + thread_coord.strided(); - Index column = threadblock_offset.column() + thread_coord.contiguous(); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; - int filter_c = column + c * ThreadMap::Delta::kContiguous; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0); - - int pred_idx = c + s * ThreadMap::Iterations::kContiguous; - - predicates_[v] |= (pred << pred_idx); - } - } - } - - pointer_ += ( - filter_k_ * params.layout.stride()[2] + column - ) * sizeof_bits::value / 8; - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - - LongIndex next = params_.inc_next_rs; - - // moves to the next tile - ++filter_rs_; - if (filter_rs_ == params_.RS) { - - filter_rs_ = 0; - next = params_.inc_next_k; - filter_k_ += params_.filter_k_delta; - } - - // Clear predicates if needed - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { - uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - predicates_[v] = (predicates_[v] & (~kClearMask)); - } - } - } - - pointer_ += next; - } - - /// Returns true if the current coordinate is within the filter tensor W - CUTLASS_HOST_DEVICE - bool valid() { - LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; - return (predicates_[iteration_vector_] & (1u << pred_idx)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - return reinterpret_cast(pointer_ + - iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + iteration_vector_; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dDgradFilterTileAccessIteratorOptimized &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - - // Move to the next K coordinate within the tile - pointer_ += params_.inc_next_strided; - - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h deleted file mode 100644 index b33645c1783c8b12cc9d8d6e1d93dbffb3f47f1c..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ /dev/null @@ -1,606 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/functional.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dDgradOutputGradientTileAccessIteratorAnalytic; -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Conv2dDgradOutputGradientTileAccessIteratorAnalytic strided dgrad needs special handling using -// unscaled coordinations -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ -> -class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < - Shape_, - Element_, - ThreadMap_, - conv::StrideSupport::kStrided, - AccessType_ -> { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static_assert(sizeof_bits::value >= 8, - "DGRAD requires elements of size 8b or greater."); - - // - // Simpligying assertions - // - - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - int filter_k_; - int filter_r_; - int filter_s_; - int start_r_; - int start_s_; - - int offset_n_[ThreadMap::Iterations::kStrided]; - int offset_p_[ThreadMap::Iterations::kStrided]; - int offset_q_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientTileAccessIteratorAnalytic( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, - int start_r, int start_s, - MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - filter_k_(0), - filter_r_(start_r), - filter_s_(start_s), - start_r_(start_r), - start_s_(start_s) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); - - int filter_r = filter_r_; - int filter_s = filter_s_; - - if (problem_size_.mode == Mode::kConvolution) { - filter_r = (problem_size_.R - 1 - filter_r); - filter_s = (problem_size_.S - 1 - filter_s); - } - - // Starting h, w positions for filter position in gemm_k=0 - int start_h, start_w; - strided_dgrad_starting_coords( - problem_size_, - stride_h_divmod, stride_w_divmod, - filter_r, filter_s, - start_h, start_w); - - // Effective P and Q for filter position required for remapping NHW rows - int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; - int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w; - - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter; - - // (STEP 1) [reorder NHW rows to start with same filter positions] - offset_n_[s] = offset_npq / (P * Q); - int residual = offset_npq % (P * Q); - - int p = (residual / Q); - int q = (residual % Q); - - int mapped_h = (start_h + p * problem_size_.stride_h); - int mapped_w = (start_w + q * problem_size_.stride_w); - - // Access (p, q) coordinates for Dy tensor and a filter position in gemm_k=0 - // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are divisible - // by stride_h and stride_w - offset_p_[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; - offset_q_[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; - } - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - - // Move filter_s by stride_w - filter_s_ += problem_size_.stride_w; - if (filter_s_ < problem_size_.S) { - return; - } - - // Restore filter_s - filter_s_ = start_s_; - - // Move filter_r by stride_h - filter_r_ += problem_size_.stride_h; - if (filter_r_ < problem_size_.R) { - return; - } - - // Restore filter_r - filter_r_ = start_r_; - - // Move filter_k - filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; - } - - /// Returns the coordinate in the output tensor Dy that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - int n = offset_n_[iteration_strided_]; - int p = offset_p_[iteration_strided_]; - int q = offset_q_[iteration_strided_]; - - int conv_sign = (problem_size_.mode == Mode::kConvolution ? 1 : -1); - - p += (conv_sign * (filter_r_ / problem_size_.stride_h)); - q += (conv_sign * (filter_s_ / problem_size_.stride_w)); - - int k = filter_k_ + iteration_vector_ * AccessType::kElements; - - return TensorCoord( - n, - p, - q, - k); - } - - - /// Returns true if the current coordinate is within the output tensor Dy - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - return - coord.n() < problem_size_.N && - coord.h() >= 0 && coord.h() < problem_size_.P && - coord.w() >= 0 && coord.w() < problem_size_.Q && - coord.c() < problem_size_.K; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Conv2dDgradOutputGradientTileAccessIteratorAnalytic for unity strides can be optimized by -// eliminating modulo arithmetic to compute unscaled coordinates -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ -> -class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < - Shape_, - Element_, - ThreadMap_, - conv::StrideSupport::kUnity, - AccessType_ -> { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static_assert(sizeof_bits::value >= 8, - "DGRAD requires elements of size 8b or greater."); - - // - // Simpligying assertions - // - - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - struct Params { - - Layout layout; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params( - Conv2dProblemSize const &problem_size, - Layout const &layout - ): layout(layout) { - - } - }; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - int filter_k_; - int filter_r_; - int filter_s_; - - int offset_n_[ThreadMap::Iterations::kStrided]; - int offset_w_[ThreadMap::Iterations::kStrided]; - int offset_h_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientTileAccessIteratorAnalytic( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - filter_k_(0), - filter_r_(0), - filter_s_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - - offset_n_[s] = offset_nhw / (problem_size_.H * problem_size_.W); - int residual = offset_nhw % (problem_size_.H * problem_size_.W); - - offset_h_[s] = residual / problem_size_.W; - offset_w_[s] = residual % problem_size_.W; - } - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, layout); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // move to the next tile - ++filter_s_; - if (filter_s_ < problem_size_.S) { - return; - } - filter_s_ = 0; - ++filter_r_; - if (filter_r_ < problem_size_.R) { - return; - } - filter_r_ = 0; - - filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; - } - - /// Returns the coordinate in the output tensor Dy that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int n = offset_n_[iteration_strided_]; - int h = offset_h_[iteration_strided_]; - int w = offset_w_[iteration_strided_]; - - int r = filter_r_; - int s = filter_s_; - - if (problem_size_.mode == Mode::kConvolution) { - r = (problem_size_.R - 1 - r); - s = (problem_size_.S - 1 - s); - } - - int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h) / problem_size_.stride_h; - int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w) / problem_size_.stride_w; - - int k = filter_k_ + iteration_vector_ * AccessType::kElements; - - return TensorCoord(n, p, q, k); - } - - /// Returns true if the current coordinate is within the output tensor Dy - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - return coord.n() < problem_size_.N && - coord.h() >= 0 && coord.h() < problem_size_.P && - coord.w() >= 0 && coord.w() < problem_size_.Q && - coord.c() < problem_size_.K; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // Conv2dDgradFilterTileAccessIteratorAnalytic unity stride specialization - // only supports (stride_h, stride_w) = (1, 1) - if (problem_size.stride() != MatrixCoord({1, 1})) { - return Status::kErrorNotSupported; - } - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h deleted file mode 100644 index 638c6607095ce85f7c1b135296d974bdf295621a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ /dev/null @@ -1,821 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dDgradOutputGradientTileAccessIteratorOptimized; -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Conv2dDgradOutputGradientTileAccessIteratorOptimized strided dgrad needs special handling -// to skip MMAs (Dx = Dy * w) on invalid filter positions -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ -> -class Conv2dDgradOutputGradientTileAccessIteratorOptimized < - Shape_, - Element_, - ThreadMap_, - conv::StrideSupport::kStrided, - AccessType_ -> { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - using Mask = uint64_t; - - static_assert(sizeof_bits::value >= 8, - "DGRAD requires elements of size 8b or greater."); - - // - // Simpligying assertions - // - - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv2dStridedDgradOutputGradientIteratorOptimizedParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - - // One pointer per access - char const *pointer_[ThreadMap::Iterations::kStrided]; - - int filter_k_; - int filter_r_; - int filter_s_; - int start_r_; - int start_s_; - int64_t reset_bytes_s_; - int64_t reset_bytes_r_; - - Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientTileAccessIteratorOptimized( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, - int start_r, int start_s, - MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles - ): - params_(params), - problem_size_(problem_size), - filter_k_(0), - filter_r_(start_r), - filter_s_(start_s), - start_r_(start_r), - start_s_(start_s) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); - - reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0]; - - reset_bytes_r_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0] + - (problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1]; - - int offset_n[ThreadMap::Iterations::kStrided]; - int offset_p[ThreadMap::Iterations::kStrided]; - int offset_q[ThreadMap::Iterations::kStrided]; - - int filter_r = filter_r_; - int filter_s = filter_s_; - - if (problem_size_.mode == Mode::kConvolution) { - filter_r = (problem_size_.R - 1 - filter_r); - filter_s = (problem_size_.S - 1 - filter_s); - } - - // Starting h, w positions for filter position in gemm_k=0 - int start_h, start_w; - strided_dgrad_starting_coords( - problem_size_, - stride_h_divmod, stride_w_divmod, - filter_r, filter_s, - start_h, start_w); - - - // Effective starting P and Q for filter position required for remapping NHW rows - int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; - int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - pointer_[s] = reinterpret_cast(ptr); - - int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter; - - // (STEP 1) [reorder NHW rows to start with same filter positions] - offset_n[s] = offset_npq / (P * Q); - int residual = offset_npq % (P * Q); - - int p = (residual / Q); - int q = (residual % Q); - - int mapped_h = (start_h + p * problem_size_.stride_h); - int mapped_w = (start_w + q * problem_size_.stride_w); - - // Access (p, q) coordinates for Dy tensor for filter position in gemm_k=0 - // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are ensured to be - // divisible by stride_h and stride_w - offset_p[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; - offset_q[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; - - // Initialize pointers for gemm_k=0 - TensorCoord coord{offset_n[s], offset_p[s], offset_q[s], filter_k_}; - - pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; - } - - // - // Precompute mask predicates - // - clear_mask(); - - CUTLASS_PRAGMA_NO_UNROLL - for (int r = start_r; r < problem_size_.R; r += problem_size_.stride_h) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int p = offset_p[s_idx] ; - - p += (params_.conv_sign * (r / problem_size_.stride_h)); - - bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P); - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - masks_[s_idx][v_idx][0] |= (pred << r); - } - } - } - - CUTLASS_PRAGMA_NO_UNROLL - for(int s = start_s; s < problem_size_.S; s += problem_size_.stride_w) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int q = offset_q[s_idx]; - q += (params_.conv_sign * (s / problem_size_.stride_w)); - - bool pred = (q >=0 && q < problem_size_.Q); - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - masks_[s_idx][v_idx][1] |= (pred << s); - } - } - } - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size.K); - } - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}); - } - -private: - - /// Adds a pointer offset in units of element - CUTLASS_HOST_DEVICE - void add_byte_offset_(LongIndex byte_offset, LongIndex byte_reset = 0) { - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - pointer_[s] += byte_offset - byte_reset; - } - } - -public: - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - add_byte_offset_(pointer_offset * sizeof_bits::value / 8); - } - - CUTLASS_DEVICE - void advance() { - - int next_idx = 0; - int64_t reset_bytes = 0; - - // Move filter_s by stride_w - filter_s_ += problem_size_.stride_w; - if (filter_s_ >= problem_size_.S) { - - // Restore filter_s - filter_s_ = start_s_; - - // Move filter_r by stride_h - filter_r_ += problem_size_.stride_h; -#if 0 - if (filter_r_ < problem_size_.R) { - - next_idx = 1; - - // Restore bytes in q coordinate (Mma in filter s dimension) - reset_bytes = reset_bytes_s_; - - } else { - - // Restore filter_r - filter_r_ = start_r_; - - next_idx = 2; - - // Restore bytes in p and q coordinate (Mma in filter s and r dimension) - reset_bytes = reset_bytes_r_; - } -#else - asm volatile( - "{\n\t" - " .reg .pred %%p;\n\t" - " setp.lt.s32 %%p, %3, %4;\n\t" - " selp.s32 %0, %3, %5, %%p;\n\t" - " selp.s32 %1, 1, 2, %%p;\n\t" - " selp.s64 %2, %6, %7, %%p;\n\t" - "}\n" - : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) - : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), - "l"(reset_bytes_s_), "l"(reset_bytes_r_)); -#endif - } - - // offset pointers by offset_bytes - add_byte_offset_(params_.inc_next[next_idx] - reset_bytes); - - if (next_idx == 2) { - filter_k_ += params_.filter_k_delta; - } - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K); - } - } - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask(bool clear = true) { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; - masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; - } - } - } - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask(int v, bool clear = true) { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; - masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; - } - } - - /// Returns true if the current coordinate is within the output tensor Dy - CUTLASS_HOST_DEVICE - bool valid() const { - return - (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && - (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - // Limit on filter size - if (problem_size.R > 32 || problem_size.S > 32) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Conv2dDgradOutputGradientTileAccessIteratorOptimized unity stride dgrad is optimized for dgrad -// with problem stride = {1x1} -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ -> -class Conv2dDgradOutputGradientTileAccessIteratorOptimized < - Shape_, - Element_, - ThreadMap_, - conv::StrideSupport::kUnity, - AccessType_ -> { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - using Mask = uint64_t; - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv2dDgradOutputGradientIteratorOptimizedParams; - -private: - - Conv2dDgradOutputGradientIteratorOptimizedParams const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - - // One pointer per access - char const *pointer_[ThreadMap::Iterations::kStrided]; - - // current filter position (r, s) - int filter_r_; - int filter_s_; - int filter_k_; - - Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientTileAccessIteratorOptimized( - Conv2dDgradOutputGradientIteratorOptimizedParams const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles - ): - params_(params), - problem_size_(problem_size), - filter_k_(0), - filter_r_(0), - filter_s_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); - - int offset_n[ThreadMap::Iterations::kStrided]; - int offset_h[ThreadMap::Iterations::kStrided]; - int offset_w[ThreadMap::Iterations::kStrided]; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - pointer_[s] = reinterpret_cast(ptr); - - int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - - // The subseqnet fast_divmod() operations are equivalent to the following logical computation: - // - // - // offset_n[s] = offset_nhw / (problem_size_.H * problem_size_.W); - // int residual = offset_nhw % (problem_size_.H * problem_size_.W); - // - // offset_h[s] = residual / problem_size_.W; - // offset_w[s] = residual % problem_size_.W; - // - - int residual; - - params_.hw_divmod(offset_n[s], residual, offset_nhw); - params_.w_divmod(offset_h[s], offset_w[s], residual); - - TensorCoord coord = at_(offset_n[s], offset_h[s], offset_w[s], 0, 0); - - pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; - } - - clear_mask(); - - CUTLASS_PRAGMA_NO_UNROLL - for (int r = 0; r < problem_size_.R; ++r) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int r_ = r; - if (problem_size_.mode == Mode::kConvolution) { - r_ = problem_size_.R - 1 - r; - } - - int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h; - - bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P); - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - masks_[s_idx][v_idx][0] |= (pred << r); - } - } - } - - CUTLASS_PRAGMA_NO_UNROLL - for (int s = 0; s < problem_size_.S; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int s_ = s; - if (problem_size_.mode == Mode::kConvolution) { - s_ = problem_size_.S - 1 - s; - } - - int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w; - - bool pred = (q >= 0 && q < problem_size_.Q); - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - masks_[s_idx][v_idx][1] |= (pred << s); - } - } - } - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, filter_k_ + v_idx * AccessType::kElements >= problem_size.K); - } - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); - } - -private: - - /// Returns the coordinate in the output gradient tensor dy that is correspoinding to - // activation nhw and filter position k, r, s - CUTLASS_HOST_DEVICE - TensorCoord at_(int n, int h, int w, int r, int s) const { - - if (problem_size_.mode == Mode::kConvolution) { - r = problem_size_.R - 1 - r; - s = problem_size_.S - 1 - s; - } - - int p = h + problem_size_.pad_h - r * problem_size_.dilation_h; - int q = w + problem_size_.pad_w - s * problem_size_.dilation_w; - - return TensorCoord(n, p, q, filter_k_); - } - - /// Adds a pointer offset in units of element - CUTLASS_HOST_DEVICE - void add_byte_offset_(LongIndex byte_offset) { - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - pointer_[s] += byte_offset; - } - } - -public: - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - add_byte_offset_(pointer_offset * sizeof_bits::value / 8); - } - - CUTLASS_HOST_DEVICE - void advance() { - - int next_idx = 0; - - // moves to the next tile - ++filter_s_; - if (filter_s_ == problem_size_.S) { - filter_s_ = 0; - ++filter_r_; - - if (filter_r_ < problem_size_.R) { - next_idx = 1; - } - else { - filter_r_ = 0; - next_idx = 2; - } - } - - add_byte_offset_(params_.inc_next[next_idx]); - - if (next_idx == 2) { - filter_k_ += params_.filter_k_delta; - } - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K); - } - } - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask(bool clear = true) { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; - masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; - } - } - } - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask(int v, bool clear = true) { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; - masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; - } - } - - CUTLASS_HOST_DEVICE - bool valid() { - - return - (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && - (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // This is specialized for unit stride - if (problem_size.stride() != MatrixCoord({1, 1})) { - return Status::kErrorNotSupported; - } - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { - return Status::kErrorNotSupported; - } - - // Limit on filter size - if (problem_size.R > 32 || problem_size.S > 32) { - return Status::kErrorNotSupported; - } - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h deleted file mode 100644 index e4eb011e1c675757b9f1fa3111c2de0db658cad5..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h +++ /dev/null @@ -1,332 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) - matrix from memory. - - This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename Layout_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray, - conv::GroupMode GroupMode_ = conv::GroupMode::kNone -> -class Conv2dFpropActivationTileAccessIteratorAnalytic { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - static conv::GroupMode const kGroupMode = GroupMode_; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv2dAnalyticParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - int filter_c_; - int filter_r_; - int filter_s_; - int filter_c_init_; - int group_idx_offset_; - int channels_per_group_; - int crs_cnt_; - int crs_per_group_; - - int offset_n_[ThreadMap::Iterations::kStrided]; - int offset_p_[ThreadMap::Iterations::kStrided]; - int offset_q_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dFpropActivationTileAccessIteratorAnalytic( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - crs_cnt_(0), - group_idx_offset_(0), - filter_c_(0), - filter_r_(0), - filter_s_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); - - if (kGroupMode != conv::GroupMode::kNone) { - filter_c_init_ = filter_c_; - channels_per_group_ = problem_size_.C / problem_size_.groups; - crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kColumn - 1) / Shape::kColumn); - } - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - - offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); - int residual = offset_npq % (problem_size_.P * problem_size_.Q); - - offset_p_[s] = residual / problem_size_.Q; - offset_q_[s] = residual % problem_size_.Q; - } - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, layout); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next tile - if (kGroupMode != conv::GroupMode::kNone) { - ++crs_cnt_; - } - - ++filter_s_; - if (filter_s_ < problem_size_.S) { - return; - } - filter_s_ = 0; - ++filter_r_; - if (filter_r_ < problem_size_.R) { - return; - } - filter_r_ = 0; - - if (kGroupMode == conv::GroupMode::kNone) { - filter_c_ += Shape::kColumn * problem_size_.split_k_slices; - } else { - if (crs_cnt_ == crs_per_group_) { - // moves to next group - crs_cnt_ = 0; - ++group_idx_offset_; - filter_c_ = group_idx_offset_ * channels_per_group_ + filter_c_init_; - } else { - filter_c_ += Shape::kColumn * problem_size_.split_k_slices; - } - } - } - - /// Returns the coordinate in the activations tensor X that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - int n = offset_n_[iteration_strided_]; - int p = offset_p_[iteration_strided_]; - int q = offset_q_[iteration_strided_]; - - int r = filter_r_; - int s = filter_s_; - - if (problem_size_.mode == Mode::kConvolution) { - r = (problem_size_.R - 1 - filter_r_); - s = (problem_size_.S - 1 - filter_s_); - } - - int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; - int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; - - int c = filter_c_ + iteration_vector_ * AccessType::kElements; - - return TensorCoord(n, h, w, c); - } - - /// Returns true if the current coordinate is within the activations tensor X - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - return coord.n() < problem_size_.N && - coord.h() >= 0 && coord.h() < problem_size_.H && - coord.w() >= 0 && coord.w() < problem_size_.W && - coord.c() < problem_size_.C; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - - return ptr; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dFpropActivationTileAccessIteratorAnalytic &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if ((problem_size.C / problem_size.groups) % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - if (platform::is_same>::value) { - if (problem_size.C % 32) { - return Status::kErrorInvalidProblem; - } - } - - if (platform::is_same>::value) { - if (problem_size.C % 64) { - return Status::kErrorInvalidProblem; - } - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h deleted file mode 100644 index c608ce5305039ce42bd017fd74f14658a6c593da..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h +++ /dev/null @@ -1,360 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) - matrix from memory. - - This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename Layout_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dFpropActivationTileAccessIteratorFewChannels { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFewChannels; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kPositionsPerTile = Shape::kColumn; - - static int const kAccessesPerVector = kElementsPerAccess / AccessType::kElements; - - static bool const kUseFastDivmodPrologue = true; - static bool const kUseFastDivmodMainloop = true; - - static int const kStrideH = 0; - static int const kStrideW = 0; - static int const kDilationH = 0; - static int const kDilationW = 0; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv2dFewChannelsParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - int rsc_index_; - int offset_n_[ThreadMap::Iterations::kStrided]; - int offset_p_[ThreadMap::Iterations::kStrided]; - int offset_q_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dFpropActivationTileAccessIteratorFewChannels( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - rsc_index_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - rsc_index_ = (threadblock_offset.column() + thread_coord.contiguous()); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - - if (kUseFastDivmodPrologue) { - int residual = params_.divmod_Q.divmod(offset_q_[s], offset_npq); - offset_n_[s] = params_.divmod_P.divmod(offset_p_[s], residual); - } - else { - offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); - int residual = offset_npq % (problem_size_.P * problem_size_.Q); - - offset_p_[s] = residual / problem_size_.Q; - offset_q_[s] = residual % problem_size_.Q; - } - } - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, layout); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - - rsc_index_ += kPositionsPerTile * problem_size_.split_k_slices; - } - - /// Returns the coordinate in the activations tensor X that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - int n = offset_n_[iteration_strided_]; - int p = offset_p_[iteration_strided_]; - int q = offset_q_[iteration_strided_]; - - int rsc_index = rsc_index_ + iteration_vector_ * AccessType::kElements; - - int r = 0; - int s = 0; - int c = 0; - - if (kUseFastDivmodMainloop) { - int rs_index = params_.divmod_C.divmod(c, rsc_index); - r = params_.divmod_S.divmod(s, rs_index); - } - else { - c = (rsc_index % problem_size_.C); - - int rs_index = (rsc_index / problem_size_.C); - s = (rs_index % problem_size_.S); - r = (rs_index / problem_size_.S); - } - - if (problem_size_.mode == Mode::kConvolution) { - r = (problem_size_.R - 1 - r); - s = (problem_size_.S - 1 - s); - } - - int stride_h = kStrideH; - if (!kStrideH) { - stride_h = problem_size_.stride_h; - } - - int stride_w = kStrideW; - if (!kStrideW) { - stride_w = problem_size_.stride_w; - } - - int dilation_h = kDilationH; - if (!kDilationH) { - dilation_h = problem_size_.dilation_h; - } - - int dilation_w = kDilationW; - if (!kDilationW) { - dilation_w = problem_size_.dilation_w; - } - - int h = p * stride_h - problem_size_.pad_h + r * dilation_h; - int w = q * stride_w - problem_size_.pad_w + s * dilation_w; - - return TensorCoord(n, h, w, c); - } - - /// Returns true if the current coordinate is within the activations tensor X - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - bool in_bounds = - coord.n() < problem_size_.N && - coord.h() >= 0 && coord.h() < problem_size_.H && - coord.w() >= 0 && coord.w() < problem_size_.W && - coord.c() < problem_size_.C; - - return in_bounds; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - - int32_t offset = - coord.n() * params_.stride_n + - coord.h() * params_.stride_h + - coord.w() * params_.stride_w + - coord.c(); - - AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - - return ptr; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dFpropActivationTileAccessIteratorFewChannels &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - if (kDilationH && problem_size.dilation_h != kDilationH) { - return Status::kErrorInvalidProblem; - } - - if (kDilationW && problem_size.dilation_w != kDilationW) { - return Status::kErrorInvalidProblem; - } - - if (kStrideH && problem_size.stride_h != kStrideH) { - return Status::kErrorInvalidProblem; - } - - if (kStrideW && problem_size.stride_w != kStrideW) { - return Status::kErrorInvalidProblem; - } - - if (platform::is_same>::value) { - if (problem_size.C % 32) { - return Status::kErrorInvalidProblem; - } - } - - if (platform::is_same>::value) { - if (problem_size.C % 64) { - return Status::kErrorInvalidProblem; - } - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h deleted file mode 100644 index ed0e38c285c78ba570506074f40f6bc5cff45a76..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h +++ /dev/null @@ -1,353 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) - matrix from memory. - - This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename Layout_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dFpropActivationTileAccessIteratorFixedChannels { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFixedChannels; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kFilterPositionsPerTile = Shape::kColumn / AccessType::kElements; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static bool const kUseFastDivmodPrologue = true; - static bool const kUseFastDivmodMainloop = true; - - static int const kStrideH = 0; - static int const kStrideW = 0; - static int const kDilationH = 0; - static int const kDilationW = 0; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv2dFewChannelsParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - int rs_index_; - int offset_n_[ThreadMap::Iterations::kStrided]; - int offset_p_[ThreadMap::Iterations::kStrided]; - int offset_q_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dFpropActivationTileAccessIteratorFixedChannels( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - rs_index_(0) { - - // - // This requires problem_size.C == AccessType::kElements - // - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - rs_index_ = (threadblock_offset.column() + thread_coord.contiguous()) / AccessType::kElements; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - - if (kUseFastDivmodPrologue) { - int residual = params_.divmod_Q.divmod(offset_q_[s], offset_npq); - offset_n_[s] = params_.divmod_P.divmod(offset_p_[s], residual); - } - else { - offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); - int residual = offset_npq % (problem_size_.P * problem_size_.Q); - - offset_p_[s] = residual / problem_size_.Q; - offset_q_[s] = residual % problem_size_.Q; - } - } - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, layout); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - - rs_index_ += kFilterPositionsPerTile * problem_size_.split_k_slices; - } - - /// Returns the coordinate in the activations tensor X that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - int n = offset_n_[iteration_strided_]; - int p = offset_p_[iteration_strided_]; - int q = offset_q_[iteration_strided_]; - - int rs_index = rs_index_ + iteration_vector_; - - int r = 0; - int s = 0; - - if (kUseFastDivmodMainloop) { - r = params_.divmod_S.divmod(s, rs_index); - } - else { - s = (rs_index % problem_size_.S); - r = (rs_index / problem_size_.S); - } - - if (problem_size_.mode == Mode::kConvolution) { - r = (problem_size_.R - 1 - r); - s = (problem_size_.S - 1 - s); - } - - int stride_h = kStrideH; - if (!kStrideH) { - stride_h = problem_size_.stride_h; - } - - int stride_w = kStrideW; - if (!kStrideW) { - stride_w = problem_size_.stride_w; - } - - int dilation_h = kDilationH; - if (!kDilationH) { - dilation_h = problem_size_.dilation_h; - } - - int dilation_w = kDilationW; - if (!kDilationW) { - dilation_w = problem_size_.dilation_w; - } - - int h = p * stride_h - problem_size_.pad_h + r * dilation_h; - int w = q * stride_w - problem_size_.pad_w + s * dilation_w; - - return TensorCoord(n, h, w, 0); - } - - /// Returns true if the current coordinate is within the activations tensor X - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - return coord.n() < problem_size_.N && - coord.h() >= 0 && coord.h() < problem_size_.H && - coord.w() >= 0 && coord.w() < problem_size_.W; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - - int32_t offset = - coord.n() * params_.stride_n + - coord.h() * params_.stride_h + - coord.w() * params_.stride_w + coord.c(); - - AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - - return ptr; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dFpropActivationTileAccessIteratorFixedChannels &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C != AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - if (kDilationH && problem_size.dilation_h != kDilationH) { - return Status::kErrorInvalidProblem; - } - - if (kDilationW && problem_size.dilation_w != kDilationW) { - return Status::kErrorInvalidProblem; - } - - if (kStrideH && problem_size.stride_h != kStrideH) { - return Status::kErrorInvalidProblem; - } - - if (kStrideW && problem_size.stride_w != kStrideW) { - return Status::kErrorInvalidProblem; - } - - if (platform::is_same>::value) { - if (problem_size.C % 32) { - return Status::kErrorInvalidProblem; - } - } - - if (platform::is_same>::value) { - if (problem_size.C % 64) { - return Status::kErrorInvalidProblem; - } - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h deleted file mode 100644 index 1a5c33e885be7521981e2d4bc5fc35f3b1412ebe..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h +++ /dev/null @@ -1,422 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) - matrix from memory. - - This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename Layout_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dFpropActivationTileAccessIteratorOptimized { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - using Mask = uint64_t; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv2dFpropActivationIteratorOptimizedParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - - // One pointer per access - char const *pointer_[ThreadMap::Iterations::kStrided]; - - // current filter position (r, s) - int filter_r_; - int filter_s_; - int filter_c_; - - Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dFpropActivationTileAccessIteratorOptimized( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles - ): - params_(params), - problem_size_(problem_size), - filter_c_(0), - filter_r_(0), - filter_s_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); - - int offset_n[ThreadMap::Iterations::kStrided]; - int offset_p[ThreadMap::Iterations::kStrided]; - int offset_q[ThreadMap::Iterations::kStrided]; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - pointer_[s] = reinterpret_cast(ptr); - - int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - - // The subseqnet fast_divmod() operations are equivalent to the following logical computation: - // - // - // offset_n[s] = offset_npq / (problem_size_.P * problem_size_.Q); - // int residual = offset_npq % (problem_size_.P * problem_size_.Q); - // - // offset_p[s] = residual / problem_size_.Q; - // offset_q[s] = residual % problem_size_.Q; - // - - int residual; - - params.pq_divmod(offset_n[s], residual, offset_npq); - params.q_divmod(offset_p[s], offset_q[s], residual); - - TensorCoord coord = at_(offset_n[s], offset_p[s], offset_q[s], 0, 0); - - pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; - } - - clear_mask(); - - CUTLASS_PRAGMA_NO_UNROLL - for (int r = 0; r < problem_size_.R; ++r) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int r_ = r; - if (problem_size_.mode == Mode::kConvolution) { - r_ = problem_size_.R - 1 - r; - } - - int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; - - bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H); - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - masks_[s_idx][v_idx][0] |= (pred << r); - } - } - } - - CUTLASS_PRAGMA_NO_UNROLL - for (int s = 0; s < problem_size_.S; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int s_ = s; - if (problem_size_.mode == Mode::kConvolution) { - s_ = problem_size_.S - 1 - s; - } - - int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; - - bool pred = (w >= 0 && w < problem_size_.W); - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - masks_[s_idx][v_idx][1] |= (pred << s); - } - } - } - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); - } - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); - } - -private: - - /// Returns the coordinate in the activations tensor X that is correspoinding to - // output npq and filter position r, s - CUTLASS_HOST_DEVICE - TensorCoord at_(int n, int p, int q, int r, int s) const { - - if (problem_size_.mode == Mode::kConvolution) { - r = problem_size_.R - 1 - r; - s = problem_size_.S - 1 - s; - } - - int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; - int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; - - return TensorCoord(n, h, w, filter_c_); - } - - /// Adds a pointer offset in units of element - CUTLASS_HOST_DEVICE - void add_byte_offset_(LongIndex byte_offset) { - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - pointer_[s] += byte_offset; - } - } - -public: - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - add_byte_offset_(pointer_offset * sizeof_bits::value / 8); - } - - CUTLASS_HOST_DEVICE - void advance() { - - int next_idx = 0; - - // moves to the next tile - ++filter_s_; - if (filter_s_ == problem_size_.S) { - filter_s_ = 0; - ++filter_r_; - - if (filter_r_ < problem_size_.R) { - next_idx = 1; - } - else { - filter_r_ = 0; - next_idx = 2; - } - } - - add_byte_offset_(params_.inc_next[next_idx]); - - if (next_idx == 2) { - filter_c_ += params_.filter_c_delta; - } - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); - } - } - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask(bool clear = true) { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; - masks_[s][v][1] = clear ? 0 : masks_[s][v][1]; - } - } - } - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask(int v, bool clear = true) { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; - masks_[s][v][1] = clear ? 0 : masks_[s][v][1]; - } - } - - CUTLASS_HOST_DEVICE - bool valid() { - - return - (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && - (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dFpropActivationTileAccessIteratorOptimized &operator++() { - - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if ((problem_size.C / problem_size.groups) % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - if (platform::is_same>::value) { - if (problem_size.C % 32) { - return Status::kErrorInvalidProblem; - } - } - - if (platform::is_same>::value) { - if (problem_size.C % 64) { - return Status::kErrorInvalidProblem; - } - } - - // Conv2dFpropActivationTileAccessIteratorOptimized has constraint on filter positions - // due to the number of mask bits. - if (problem_size.R > 32 || problem_size.S > 32) { - return Status::kErrorNotSupported; - } - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h deleted file mode 100644 index ed200ed3cf030055b3f7ba470748c91c3751fbfe..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h +++ /dev/null @@ -1,330 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) - matrix from memory. - - This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename Layout_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray, - conv::GroupMode GroupMode_ = conv::GroupMode::kNone, - bool IsDeconv_ = false -> -class Conv2dFpropFilterTileAccessIteratorAnalytic { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static bool const IsDeconv = IsDeconv_; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - static conv::GroupMode const kGroupMode = GroupMode_; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv2dAnalyticParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - int filter_r_; - int filter_s_; - int filter_c_; - int filter_c_init_; - int crs_cnt_; - int crs_per_group_; - int group_idx_offset_c_; - int channels_per_group_; - - int offset_k_[ThreadMap::Iterations::kStrided]; - int group_idx_offset_k_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dFpropFilterTileAccessIteratorAnalytic( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - crs_cnt_(0), - group_idx_offset_c_(0), - filter_r_(0), - filter_s_(0), - filter_c_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); - - auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); - auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); - - if (kGroupMode != conv::GroupMode::kNone) { - filter_c_init_ = filter_c_; - if (kGroupMode == conv::GroupMode::kDepthwise){ - channels_per_group_ = 1; - crs_per_group_ = problem_size_.S * problem_size_.R; - } else { - channels_per_group_ = input_channels / problem_size_.groups; - crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kRow - 1) / Shape::kRow); - } - } - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - if (kGroupMode != conv::GroupMode::kNone && kGroupMode != conv::GroupMode::kDepthwise) { - group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (output_channels / problem_size_.groups); - } - } - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * 8 / sizeof_bits::value; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next tile - if (kGroupMode != conv::GroupMode::kNone) { - ++crs_cnt_; - } - - ++filter_s_; - if (filter_s_ < problem_size_.S) { - return; - } - filter_s_ = 0; - - ++filter_r_; - if (filter_r_ < problem_size_.R) { - return; - } - filter_r_ = 0; - - if (kGroupMode == conv::GroupMode::kNone) { - filter_c_ += Shape::kRow * problem_size_.split_k_slices; - } else { - if (crs_cnt_ == crs_per_group_) { - crs_cnt_ = 0; - filter_c_ = filter_c_init_; - if (kGroupMode != conv::GroupMode::kDepthwise) { - // moves to next group - ++group_idx_offset_c_; - } - } else { - filter_c_ += Shape::kRow * problem_size_.split_k_slices; - } - } - } - - /// Returns the coordinate in the filter tensor W that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int k = offset_k_[iteration_strided_]; - int c = filter_c_ + iteration_vector_ * AccessType::kElements; - - return TensorCoord(k, filter_r_, filter_s_, c); - } - - /// Returns true if the current coordinate is within the activations tensor W - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); - auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); - - if (kGroupMode == conv::GroupMode::kNone) { - return coord.n() < output_channels && coord.c() < input_channels; - } else if (kGroupMode == conv::GroupMode::kDepthwise) { - return coord.n() < output_channels && coord.c() < 1; // channels_per_group_ is always equal to ONE. - } else { - return coord.n() < output_channels && coord.c() < channels_per_group_ && - group_idx_offset_c_ == group_idx_offset_k_[iteration_strided_]; - } - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dFpropFilterTileAccessIteratorAnalytic &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); - auto output_channels = (IsDeconv ? problem_size.C : problem_size.K); - - // check alignment constraint on iterator's contiguous dimension - if ((input_channels / problem_size.groups) % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - if (platform::is_same>::value) { - if (output_channels % 32) { - return Status::kErrorInvalidProblem; - } - } - - if (platform::is_same>::value) { - if (output_channels % 64) { - return Status::kErrorInvalidProblem; - } - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h deleted file mode 100644 index f208c9a5bb2ee697626a8caebc5715073ecdc7eb..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h +++ /dev/null @@ -1,289 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) - matrix from memory. - - This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename Layout_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dFpropFilterTileAccessIteratorFewChannels { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFewChannels; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kPositionsPerTile = Shape::kRow; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static bool const kUseFastDivmodPrologue = true; - static bool const kUseFastDivmodMainloop = true; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv2dFewChannelsParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - int rsc_index_; - - int offset_k_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dFpropFilterTileAccessIteratorFewChannels( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - rsc_index_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - rsc_index_ = (threadblock_offset.row() + thread_coord.contiguous()); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - } - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * 8 / sizeof_bits::value; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next tile - rsc_index_ += kPositionsPerTile * problem_size_.split_k_slices; - } - - /// Returns the coordinate in the filter tensor W that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int rsc_index = rsc_index_ + iteration_vector_ * AccessType::kElements; - - int c = 0; - int s = 0; - int r = 0; - - if (kUseFastDivmodMainloop) { - int rs_index = params_.divmod_C.divmod(c, rsc_index); - r = params_.divmod_S.divmod(s, rs_index); - } - else { - c = (rsc_index % problem_size_.C); - int rs_index = (rsc_index / problem_size_.C); - - s = (rs_index % problem_size_.S); - r = (rs_index / problem_size_.S); - } - - int k = offset_k_[iteration_strided_]; - - return TensorCoord(k, r, s, c); - } - - /// Returns true if the current coordinate is within the activations tensor W - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - bool in_bounds = - coord.n() < problem_size_.K && - coord.h() >= 0 && - coord.h() < problem_size_.R && - coord.c() < problem_size_.C; - - return in_bounds; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - - int32_t offset = - coord.n() * params_.stride_n + - coord.h() * params_.stride_h + - coord.w() * params_.stride_w + - coord.c(); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dFpropFilterTileAccessIteratorFewChannels &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - if (platform::is_same>::value) { - if (problem_size.K % 32) { - return Status::kErrorInvalidProblem; - } - } - - if (platform::is_same>::value) { - if (problem_size.K % 64) { - return Status::kErrorInvalidProblem; - } - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h deleted file mode 100644 index 2dc2151d8ba2759d55f6602024a5072b31789cf6..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h +++ /dev/null @@ -1,275 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) - matrix from memory. - - This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename Layout_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dFpropFilterTileAccessIteratorFixedChannels { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFixedChannels; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kFilterPositionsPerTile = Shape::kRow / AccessType::kElements; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static bool const kUseFastDivmodPrologue = true; - static bool const kUseFastDivmodMainloop = true; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv2dFewChannelsParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - int rs_index_; - - int offset_k_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dFpropFilterTileAccessIteratorFixedChannels( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - rs_index_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - rs_index_ = (threadblock_offset.row() + thread_coord.contiguous()) / AccessType::kElements; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - } - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * 8 / sizeof_bits::value; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next tile - rs_index_ += kFilterPositionsPerTile * problem_size_.split_k_slices; - } - - /// Returns the coordinate in the filter tensor W that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int rs_index = rs_index_ + iteration_vector_; - - int r = 0; - int s = 0; - - if (kUseFastDivmodMainloop) { - r = params_.divmod_S.divmod(s, rs_index); - } - else { - s = (rs_index % problem_size_.S); - r = (rs_index / problem_size_.S); - } - - int k = offset_k_[iteration_strided_]; - - return TensorCoord(k, r, s, 0); - } - - /// Returns true if the current coordinate is within the activations tensor W - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - return coord.n() < problem_size_.K && coord.h() >= 0 && coord.h() < problem_size_.R; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - - int32_t offset = - coord.n() * params_.stride_n + - coord.h() * params_.stride_h + - coord.w() * params_.stride_w + coord.c(); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dFpropFilterTileAccessIteratorFixedChannels &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C != AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - if (platform::is_same>::value) { - if (problem_size.K % 32) { - return Status::kErrorInvalidProblem; - } - } - - if (platform::is_same>::value) { - if (problem_size.K % 64) { - return Status::kErrorInvalidProblem; - } - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h deleted file mode 100644 index 9b12fbe3390c61f9f39ed54ad27cf78e65d80dff..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ /dev/null @@ -1,322 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) - matrix from memory. - - This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" - -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename Layout_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray, - bool IsDeconv_ = false -> -class Conv2dFpropFilterTileAccessIteratorOptimized{ -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static bool const IsDeconv = IsDeconv_; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - struct Params : Conv2dFpropFilterIteratorOptimizedParams { - - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params(Conv2dFpropFilterIteratorOptimizedParams const &base): - Conv2dFpropFilterIteratorOptimizedParams(base) { } - - CUTLASS_HOST_DEVICE - Params( - Conv2dProblemSize const &problem_size, - Layout const &layout - ): - Conv2dFpropFilterIteratorOptimizedParams( - problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} - ) { - - } - }; - -private: - - Conv2dFpropFilterIteratorOptimizedParams const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - uint32_t predicates_[kAccessesPerVector]; - int filter_rs_; - int filter_c_; - int channels_per_group_; - - // - // Assertions - // - - // We map predicates into bits packed in this uint32_t container - static_assert(ThreadMap::Iterations::kStrided < sizeof(predicates_) * 8, - "Currently, the number of loads per iteration is limited by the size of the predicates container."); - -public: - - CUTLASS_HOST_DEVICE - Conv2dFpropFilterTileAccessIteratorOptimized( - Conv2dFpropFilterIteratorOptimizedParams const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - predicates_{0}, - filter_rs_(0), - filter_c_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); - Index column = threadblock_offset.column() + thread_coord.strided(); - channels_per_group_ = (IsDeconv ? problem_size_.K : problem_size_.C) / problem_size_.groups; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < (IsDeconv ? problem_size_.C : problem_size_.K)) ? 1u : 0); - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - predicates_[v_idx] |= (pred << s); - } - } - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); - } - - pointer_ += ( - params_.layout({filter_c_, column}) - ) * sizeof_bits::value / 8; - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - - LongIndex next = params_.inc_next_rs; - - // moves to the next tile - ++filter_rs_; - if (filter_rs_ == params_.RS) { - - filter_rs_ = 0; - next = params_.inc_next_c; - filter_c_ += params_.filter_c_delta; - } - - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); - } - - pointer_ += next; - } - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask(int v, bool clear = true) { - predicates_[v] = clear ? 0u : predicates_[v]; - } - - /// Returns true if the current coordinate is within the filter tensor W - CUTLASS_HOST_DEVICE - bool valid() { - return (predicates_[iteration_vector_] & (1u << iteration_strided_)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - return reinterpret_cast(pointer_) + iteration_vector_; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dFpropFilterTileAccessIteratorOptimized &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - - // Move to the next K coordinate within the tile - pointer_ += params_.inc_next_k; - - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); - auto output_channels = (IsDeconv ? problem_size.C : problem_size.K); - - // check alignment constraint on iterator's contiguous dimension - if ((input_channels / problem_size.groups) % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - if (platform::is_same>::value) { - if (output_channels % 32) { - return Status::kErrorInvalidProblem; - } - } - - if (platform::is_same>::value) { - if (output_channels % 64) { - return Status::kErrorInvalidProblem; - } - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_params.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_params.h deleted file mode 100644 index 8a3828fccb00b32d70e215785be3da1d317ed38a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_params.h +++ /dev/null @@ -1,893 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Extracts the host-params objects into non-template code. -*/ - -#pragma once - -#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" - -#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED -#include -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Params structure used for all Conv2d analytic tile iterators -template< typename Layout_ = layout::TensorNHWC > -struct Conv2dAnalyticParams { - - using Layout = Layout_; - - Layout layout; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv2dAnalyticParams() { } - - CUTLASS_HOST_DEVICE - Conv2dAnalyticParams( - Conv2dProblemSize const &, // unused; placeholder to match other Params interfaces. - Layout const &layout - ): layout(layout) { - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Params structure used for all Conv2d analytic tile iterators -template< typename Layout_ = layout::TensorNHWC > -struct Conv2dFewChannelsParams { - - using Layout = Layout_; - - - int32_t stride_w; - int32_t stride_h; - int32_t stride_n; - - FastDivmod divmod_P; - FastDivmod divmod_Q; - FastDivmod divmod_S; - FastDivmod divmod_C; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv2dFewChannelsParams() { } - - CUTLASS_HOST_DEVICE - Conv2dFewChannelsParams( - Conv2dProblemSize const &problem_size, // unused; placeholder to match other Params interfaces. - Layout const &layout - ): - stride_w(int32_t(layout.stride()[0])), - stride_h(int32_t(layout.stride()[1])), - stride_n(int32_t(layout.stride()[2])), - divmod_P(problem_size.P), - divmod_Q(problem_size.Q), - divmod_S(problem_size.S), - divmod_C(problem_size.C) - { - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters structure used for Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams -struct Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams { - - using Layout = layout::TensorNHWC; - - Layout layout; - int tiled_rows_per_filter; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams() { } - - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, ///< layout object - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape - ): layout(layout) { - - int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row()); - - tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row(); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED - -CUTLASS_HOST_DEVICE -void TraceIteratorParams( - char const *conv_operator, - char const *operand, - int element_size_bits, - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta -) { - -#if !defined(__CUDA_ARCH__) - - char const *fname = "conv_iterator_params.csv"; - - std::ifstream test(fname); - bool file_exists = test.is_open(); - - if (file_exists) { - test.close(); - } - - std::ofstream trace("conv_iterator_params.csv", std::ofstream::app); - - if (!file_exists) { - trace - << "Operator,Operand,ElementSize,CtaRows,CtaColumns,ThreadCount,AccessSize," - << "IterationsContiguous,IterationsStrided,DeltaContiguous,DeltaStrided\n"; - } - - trace << conv_operator << "," << operand << "," << element_size_bits << "," - << threadblock_shape.row() << "," << threadblock_shape.column() - << "," << thread_count << "," << access_size - << "," << threadmap_iterations.contiguous() << "," << threadmap_iterations.strided() - << "," << threadmap_delta.contiguous() << "," << threadmap_delta.strided() << "\n"; -#endif -} - -#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) \ - TraceIteratorParams(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta); - -#else - -#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) {} - -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized -template< typename Layout_ = layout::TensorNHWC > -struct Conv2dFpropActivationIteratorOptimizedParams; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized -template<> -struct Conv2dFpropActivationIteratorOptimizedParams { - - using Layout = layout::TensorNHWC; - - Layout layout; - - int64_t inc_next[3]; // {next S, next R, next C} - int filter_c_delta; // number of logical elements to add to filter_c_ - int PQ; // product of P*Q - - FastDivmod pq_divmod; - FastDivmod q_divmod; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv2dFpropActivationIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv2dFpropActivationIteratorOptimizedParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, ///< layout object - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout), - PQ(problem_size.P * problem_size.Q), - pq_divmod(PQ), - q_divmod(problem_size.Q) { - - TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - int conv_sign = (problem_size.mode == Mode::kConvolution ? -1 : 1); - - // next S - inc_next[0] = conv_sign * ( - int64_t(layout.stride()[0]) * problem_size.dilation_w - ) * element_size_bits / 8; - - // next R - inc_next[1] = conv_sign * ( - int64_t(layout.stride()[1]) * problem_size.dilation_h - - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // next C - inc_next[2] = ( - threadblock_shape.column() * problem_size.split_k_slices - - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h - - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // logical offset added to internal channel counter - units are elements, not bytes - filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; - } - -#if ENABLE_CONV2D_PARAMS_PRINT - /// Prints internal state. - CUTLASS_HOST_DEVICE - void print() { - auto stride = layout.stride(); - printf( - "Conv2dFpropActivationIteratorOptimizedParams:\n" - " layout(w: %d, h: %d, n: %d)\n" - " inc_next[%ld, %ld, %ld]\n" - " filter_c_delta(%d) - PQ(%d)\n" - " pq_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n" - " q_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n", - stride[0], stride[1], stride[2], - inc_next[0], inc_next[1], inc_next[2], - filter_c_delta, - PQ, - pq_divmod.divisor, - pq_divmod.multiplier, - pq_divmod.shift_right, - q_divmod.divisor, - q_divmod.multiplier, - q_divmod.shift_right - ); - } -#endif -}; - -/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized -template -struct Conv2dFpropActivationIteratorOptimizedParams> { - static int const kInterleaved = Interleaved_; - - using Layout = layout::TensorNCxHWx; - - Layout layout; - - int64_t inc_next[3]; // {next S, next R, next C} - int filter_c_delta; // number of logical elements to add to filter_c_ - int PQ; // product of P*Q - - FastDivmod pq_divmod; - FastDivmod q_divmod; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv2dFpropActivationIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv2dFpropActivationIteratorOptimizedParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, ///< layout object - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout), PQ(problem_size.P * problem_size.Q), pq_divmod(PQ), q_divmod(problem_size.Q) { - - TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - int conv_sign = (problem_size.mode == Mode::kConvolution ? -1 : 1); - - // next S - inc_next[0] = conv_sign * (kInterleaved * problem_size.dilation_w) * element_size_bits / 8; - - // next R - inc_next[1] = conv_sign * ( - int64_t(layout.stride()[0]) * problem_size.dilation_h - - (problem_size.S - 1) * kInterleaved * problem_size.dilation_w - ) * element_size_bits / 8; - - // next C - inc_next[2] = ( - threadblock_shape.column() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[1]) - - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[0] * problem_size.dilation_h - - conv_sign * int64_t(problem_size.S - 1) * kInterleaved * problem_size.dilation_w - ) * element_size_bits / 8; - - // logical offset added to internal channel counter - units are elements, not bytes - filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Layout_ = layout::TensorNHWC > -struct Conv2dFpropFilterIteratorOptimizedParams; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Conv2dFpropFilterIteratorOptimizedParams -{ - - using Layout = layout::TensorNHWC; - - Layout layout; - int RS; - int filter_c_delta; - - int64_t inc_next_k; // offset in units of bytes to next K position - int64_t inc_next_rs; // offset in units of bytes to next RS position - int64_t inc_next_c; // offset in units of bytes to next C position - - // - // Methods - // - CUTLASS_HOST_DEVICE - Conv2dFpropFilterIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv2dFpropFilterIteratorOptimizedParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout) { - - TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - RS = problem_size.R * problem_size.S; - - inc_next_k = (int64_t(layout.stride()[2]) * threadmap_delta.strided() * element_size_bits) / 8; - - inc_next_rs = - ( int64_t(layout.stride()[0]) - - int64_t(layout.stride()[2]) * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() - ) * element_size_bits / 8; - - inc_next_c = - ( - threadblock_shape.row() * problem_size.split_k_slices - - int64_t(RS - 1) * layout.stride()[0] - - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] - ) * element_size_bits / 8; - - filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; - } - -#if ENABLE_CONV2D_PARAMS_PRINT - /// Prints internal state. - CUTLASS_HOST_DEVICE - void print() { - auto stride = layout.stride(); - printf( - "Conv2dFpropFilterIteratorOptimizedParams:\n" - " layout[%d, %d, %d]\n" - " RS(%d), filter_c_delta(%d), inc_next(k: %ld, rs: %ld, c: %ld)\n", - stride[0], stride[1], stride[2], - RS, - filter_c_delta, - inc_next_k, inc_next_rs, inc_next_c - ); - } -#endif -}; - -template -struct Conv2dFpropFilterIteratorOptimizedParams> -{ - static int const kInterleaved = Interleaved_; - using Layout = layout::TensorCxRSKx; - - Layout layout; - int RS; - int filter_c_delta; - - int64_t inc_next_k; // offset in units of bytes to next K position - int64_t inc_next_rs; // offset in units of bytes to next RS position - int64_t inc_next_c; // offset in units of bytes to next C position - - // - // Methods - // - CUTLASS_HOST_DEVICE - Conv2dFpropFilterIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv2dFpropFilterIteratorOptimizedParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout) { - - TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - RS = problem_size.R * problem_size.S; - - inc_next_k = (kInterleaved * threadmap_delta.strided() * element_size_bits) / 8; - - inc_next_rs = - ( int64_t(layout.stride()[0]) - - kInterleaved * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() - ) * element_size_bits / 8; - - inc_next_c = - ( - threadblock_shape.row() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[2]) - - int64_t(RS - 1) * layout.stride()[0] - - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * kInterleaved - ) * element_size_bits / 8; - - filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Dgrad Optimized Dy params (layout::TensorNHWC) -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Parameters object for Conv2d DGRAD OutputGradient (dy) iterator -struct Conv2dDgradOutputGradientIteratorOptimizedParams { - - using Layout = layout::TensorNHWC; - - Layout layout; - - int64_t inc_next[3]; // {next S, next R, next K} - - int filter_k_delta; // number of logical elements to add to filter_k_ - - int HW; // product of H*W - - FastDivmod hw_divmod; - FastDivmod w_divmod; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv2dDgradOutputGradientIteratorOptimizedParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout), - HW(problem_size.H *problem_size.W), - hw_divmod(HW), - w_divmod(problem_size.W) { - - TRACE_CONV_INITIALIZERS("conv2d_dgrad", "output_gradient", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - int conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); - - // next S - inc_next[0] = conv_sign * ( - (int64_t)layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // next R - inc_next[1] = conv_sign * ( - (int64_t)layout.stride()[1] * problem_size.dilation_h - - (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // next K - inc_next[2] = ( - threadblock_shape.column() * problem_size.split_k_slices - - conv_sign * (problem_size.R - 1) * (int64_t)layout.stride()[1] * problem_size.dilation_h - - conv_sign * (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // logical offset added to internal channel counter - units are elements, not bytes - filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Strided Dgrad Optimized Dy params (layout::TensorNHWC) -///////////////////////////////////////////////////////////////////////////////////////////////// -struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams { - - using Layout = layout::TensorNHWC; - - Layout layout; - - int64_t inc_next[3]; // {next S, next R, next K} - - int filter_k_delta; // number of logical elements to add to filter_k_ - - int tiled_rows_per_filter; - - int conv_sign; - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv2dStridedDgradOutputGradientIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv2dStridedDgradOutputGradientIteratorOptimizedParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, ///< layout object - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape - ): layout(layout) { - - int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row()); - - tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row(); - - conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); - - // next S - inc_next[0] = conv_sign * ( - (int64_t)layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // next R - inc_next[1] = conv_sign * ( - (int64_t)layout.stride()[1] * problem_size.dilation_h - ) * element_size_bits / 8; - - // next K - inc_next[2] = ( - threadblock_shape.column() * problem_size.split_k_slices - ) * element_size_bits / 8; - - // logical offset added to internal channel counter - units are elements, not bytes - filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////////////////////// -// Dgrad Optimized w params (layout::TensorNHWC) -///////////////////////////////////////////////////////////////////////////////////////////////// -struct Conv2dDgradFilterIteratorOptimizedParams { - - using Layout = layout::TensorNHWC; - - Layout layout; - int RS; - int filter_k_delta; - - int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile - int64_t inc_next_rs; // offset in units of bytes to next RS position - int64_t inc_next_k; // offset in units of bytes to next K position in subsequent tile - - // - // Methods - // - CUTLASS_HOST_DEVICE - Conv2dDgradFilterIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv2dDgradFilterIteratorOptimizedParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout), RS(problem_size.R * problem_size.S) { - - TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - inc_next_strided = ((int64_t)layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; - - inc_next_rs = - ( (int64_t)layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] - ) * element_size_bits / 8; - - inc_next_k = - ( - threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] - - (problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] - ) * element_size_bits / 8; - - filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////////////////////// -// StridedDgrad Optimized w params (layout::TensorNHWC) -///////////////////////////////////////////////////////////////////////////////////////////////// -struct Conv2dStridedDgradFilterIteratorOptimizedParams { - - using Layout = layout::TensorNHWC; - - Layout layout; - int RS; - int filter_k_delta; - - int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile - int64_t inc_next[3]; // {next S, next R, next K} - int64_t reset_bytes; // offset in units of bytes to move back the pointer - // - // Methods - // - CUTLASS_HOST_DEVICE - Conv2dStridedDgradFilterIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv2dStridedDgradFilterIteratorOptimizedParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout), RS(problem_size.R * problem_size.S) { - - TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; - - // next S - inc_next[0] = - ( (int64_t)layout.stride()[0] * problem_size.stride_w - //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] - ) * element_size_bits / 8; - - // next R - inc_next[1] = - ( (int64_t)layout.stride()[1] * problem_size.stride_h - //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] - ) * element_size_bits / 8; - - // next K - inc_next[2] = - ( - threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] - //- (problem_size.R * problem_size.S - 1) * layout.stride()[0] - //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] - ) * element_size_bits / 8; - - // offset in units of bytes to move the pointer in backward direction - reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] - * element_size_bits / 8; - - filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters object for Conv2d WGRAD Output Gradient (dy) iterator -struct Conv2dWgradOutputGradientIteratorOptimizedParams { - - using Layout = layout::TensorNHWC; - - Layout layout; - - int NPQ; // precomputd product of N*P*Q for clearing predicates - - FastDivmod pq_divmod; - FastDivmod q_divmod; - - int64_t offset_next_strided; // offset in units of bytes to next npq coordinate within tile - int64_t offset_next_contiguous; // offset in units of bytes to next k coordinate within tile - int64_t inc_next_npq; // offset in units of bytes to next npq position in subsequent tile - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv2dWgradOutputGradientIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv2dWgradOutputGradientIteratorOptimizedParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout), - NPQ(problem_size.N * problem_size.P * problem_size.Q), - pq_divmod(problem_size.P * problem_size.Q), - q_divmod(problem_size.Q) { - - TRACE_CONV_INITIALIZERS("conv2d_wgrad", "output_gradient", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - // Incremental offsets in unites of bytes (number of elements) * sizeof_bits::value / 8 - offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) - * element_size_bits / 8; - - offset_next_contiguous = (threadmap_delta.contiguous()) - * element_size_bits / 8; - - inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) - * element_size_bits / 8; - } -}; - -struct Conv2dWgradActivationIteratorOptimizedParams { - - using Layout = layout::TensorNHWC; - - Layout layout; - - FastDivmod sc_divmod; - FastDivmod pq_divmod; - FastDivmod q_divmod; - FastDivmod c_divmod; - FastDivmod s_divmod; - int small_channel_conv_s_offset; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Conv2dWgradActivationIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv2dWgradActivationIteratorOptimizedParams( - Conv2dProblemSize const &problem_size, - Layout const &layout - ): - layout(layout), - sc_divmod(problem_size.S * problem_size.C), - pq_divmod(problem_size.P * problem_size.Q), - q_divmod(problem_size.Q), - c_divmod(problem_size.C), - s_divmod(problem_size.S * problem_size.dilation_w), - small_channel_conv_s_offset((problem_size.S - 1) * problem_size.dilation_w - problem_size.pad_w) { - } - - CUTLASS_HOST_DEVICE - Conv2dWgradActivationIteratorOptimizedParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - Conv2dWgradActivationIteratorOptimizedParams( - problem_size, - layout - ) { - - TRACE_CONV_INITIALIZERS("conv2d_wgrad", "activation", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - } -}; - -struct PredicatedScaleBiasVectorAccessIteratorParams { - public: - /// Default ctor - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIteratorParams() { } - - // Default ctor - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIteratorParams( - Conv2dProblemSize const &problem_size, - layout::PitchLinear const &layout) {} - - // Default ctor - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIteratorParams( - Conv2dProblemSize const &problem_size, - layout::RowMajor const &layout) {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h deleted file mode 100644 index 13bd29b7a0547eee204d642a3cd67a24709f89ea..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h +++ /dev/null @@ -1,337 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template wraps the tile access iterator concept to load whole tiles from tensors in - memory used for implicit GEMM convolution. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class TileIterator { -public: - using TileAccessIterator = TileAccessIterator_; - - using Shape = typename TileAccessIterator::Shape; - using Element = typename TileAccessIterator::Element; - using Layout = typename TileAccessIterator::Layout; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = typename TileAccessIterator::ThreadMap; - using AccessType = typename TileAccessIterator::AccessType; - using TensorRef = typename TileAccessIterator::TensorRef; - using Index = typename TileAccessIterator::Index; - using LongIndex = typename TileAccessIterator::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm; - static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport; - using Params = typename TileAccessIterator::Params; - static int const kConvDim = TileAccessIterator::kConvDim; - using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; - static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array< - Element, - ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; - -private: - - /// Internal state - TileAccessIterator tile_access_iterator_; - -public: - - /// Constructor - CUTLASS_HOST_DEVICE - TileIterator( - Params const ¶ms, - ConvProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - tile_access_iterator_(params, problem_size, ptr, thread_idx, threadblock_offset) { } - - CUTLASS_HOST_DEVICE - static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { - return TileAccessIterator::getParams(problem_size, layout); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - tile_access_iterator_.set_iteration_index(index); - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - tile_access_iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - TileIterator &operator++() { - tile_access_iterator_.advance(); - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - TileIterator operator++(int) { - TileIterator self(*this); - operator++(); - return self; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - - frag.clear(); - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); - - cutlass::arch::global_load< - AccessType, - sizeof(AccessType) - >( - frag_ptr[idx], - tile_access_iterator_.get() + pointer_offset, - tile_access_iterator_.valid() - ); - - ++tile_access_iterator_; - } - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - tile_access_iterator_.set_iteration_index(0); - load_with_pointer_offset(frag, 0); - } - - CUTLASS_DEVICE - void advance() { - tile_access_iterator_.advance(); - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(ConvProblemSize const &problem_size) { - - // dispatch to iterator implementation - return TileAccessIterator::can_implement(problem_size); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Strided Dgrad Tile Iterator -template -class TileIteratorStridedDgrad { -public: - using TileAccessIterator = TileAccessIterator_; - - using Shape = typename TileAccessIterator::Shape; - using Element = typename TileAccessIterator::Element; - using Layout = typename TileAccessIterator::Layout; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = typename TileAccessIterator::ThreadMap; - using AccessType = typename TileAccessIterator::AccessType; - using TensorRef = typename TileAccessIterator::TensorRef; - using Index = typename TileAccessIterator::Index; - using LongIndex = typename TileAccessIterator::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm; - static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport; - using Params = typename TileAccessIterator::Params; - static int const kConvDim = TileAccessIterator::kConvDim; - using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array< - Element, - ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; - -private: - - /// Internal state - TileAccessIterator tile_access_iterator_; - -public: - - /// Constructor (output gradient (Dy) OperandA ctor) - CUTLASS_HOST_DEVICE - TileIteratorStridedDgrad( - Params const ¶ms, - ConvProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, - int start_r, int start_s, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - tile_access_iterator_( - params, - problem_size, - ptr, - thread_idx, - stride_h_divmod, stride_w_divmod, - start_r, start_s, - threadblock_offset) { } - - /// Constructor (filter (w) OperandB ctor) - CUTLASS_HOST_DEVICE - TileIteratorStridedDgrad( - Params const ¶ms, - ConvProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - int start_r, int start_s, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - tile_access_iterator_(params, - problem_size, - ptr, - thread_idx, - start_r, start_s, - threadblock_offset) { } - - CUTLASS_HOST_DEVICE - static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { - return TileAccessIterator::getParams(problem_size, layout); - } - - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - tile_access_iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - TileIteratorStridedDgrad &operator++() { - tile_access_iterator_.advance(); - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - TileIteratorStridedDgrad operator++(int) { - TileIteratorStridedDgrad self(*this); - operator++(); - return self; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - - frag.clear(); - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - cutlass::arch::global_load< - AccessType, - sizeof(AccessType) - >( - frag_ptr[c + s * ThreadMap::Iterations::kContiguous], - tile_access_iterator_.get() + pointer_offset, - tile_access_iterator_.valid() - ); - - ++tile_access_iterator_; - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - tile_access_iterator_.set_iteration_index(0); - load_with_pointer_offset(frag, 0); - } - - CUTLASS_DEVICE - void advance() { - tile_access_iterator_.advance(); - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(ConvProblemSize const &problem_size) { - - // dispatch to iterator implementation - return TileAccessIterator::can_implement(problem_size); - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h deleted file mode 100644 index b5a240773b5912c9ace50916f55e8a1054092845..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h +++ /dev/null @@ -1,285 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dWgradActivationTileAccessIteratorAnalytic { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static_assert(sizeof_bits::value >= 8, - "WGRAD requires elements of size 8b or greater."); - - // - // Parameters structure - // - - using Params = Conv2dAnalyticParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - // Filter postion (r,s,c) in contiguous dimension stays constant for each gemm_iteration_k - int filter_r_[ThreadMap::Iterations::kContiguous]; - int filter_s_[ThreadMap::Iterations::kContiguous]; - int filter_c_[ThreadMap::Iterations::kContiguous]; - - int offset_npq_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dWgradActivationTileAccessIteratorAnalytic( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)) - { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - // initialize r,s,c filter position for every contiguous iteration - CUTLASS_PRAGMA_UNROLL - for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int rsc_offset = threadblock_offset.column() + thread_coord.contiguous() - + c * ThreadMap::Delta::kContiguous; - - filter_r_[c] = rsc_offset / (problem_size_.S * problem_size_.C); - int residual = rsc_offset % (problem_size_.S * problem_size_.C); - - filter_s_[c] = residual / problem_size_.C; - filter_c_[c] = residual % problem_size_.C; - } - - // initialize n, p, q offset for every strided iteration - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - offset_npq_[s] = threadblock_offset.row() + thread_coord.strided() - + s * ThreadMap::Delta::kStrided; - } - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - - // moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices; - } - } - - /// Returns the coordinate in the activation tensor x that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - int r, s, c; - - if (kAccessesPerVector == 1) { - /// One 128b aligned access fetching more than one element - c = filter_c_[iteration_contiguous_]; - r = filter_r_[iteration_contiguous_]; - s = filter_s_[iteration_contiguous_]; - } - else { - /// Multiple access to support non-128b alignment in contiguous dimension - c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C; - int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C; - s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S; - int wrap_s = (filter_s_[iteration_contiguous_] + wrap_c) / problem_size_.S; - r = filter_r_[iteration_contiguous_] + wrap_s; - } - - if (problem_size_.mode == Mode::kConvolution) { - r = (problem_size_.R - 1 - r); - s = (problem_size_.S - 1 - s); - } - - int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q); - int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q); - - int p = residual / problem_size_.Q; - int q = residual % problem_size_.Q; - - int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; - int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; - - return TensorCoord(n, h, w, c); - } - - /// Returns true if the current coordinate is within the activation tensor x - CUTLASS_HOST_DEVICE - bool valid() const { - TensorCoord coord = at(); - - return coord.n() < problem_size_.N && - coord.h() >= 0 && coord.h() < problem_size_.H && - coord.w() >= 0 && coord.w() < problem_size_.W; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dWgradActivationTileAccessIteratorAnalytic &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h deleted file mode 100644 index 56197279a5a45be6ac992ac16a86011cd9843646..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h +++ /dev/null @@ -1,321 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dWgradActivationTileAccessIteratorOptimized { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static_assert(sizeof_bits::value >= 8, - "WGRAD requires elements of size 8b or greater."); - - // - // Parameters structure - // - - using Params = Conv2dWgradActivationIteratorOptimizedParams; - -private: - - Conv2dWgradActivationIteratorOptimizedParams const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - // Precomputed effective filter postion (r,s) in contiguous dimension stays constant for each gemm_iteration_k - // required for npq -> nhw translation - int precomputed_filter_r_[ThreadMap::Iterations::kContiguous]; - int precomputed_filter_s_[ThreadMap::Iterations::kContiguous]; - - // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k - int filter_c_[ThreadMap::Iterations::kContiguous]; - - int offset_npq_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dWgradActivationTileAccessIteratorOptimized( - Conv2dWgradActivationIteratorOptimizedParams const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)) - { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - // initialize r,s,c filter position for every contiguous iteration - CUTLASS_PRAGMA_UNROLL - for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int rsc_offset = threadblock_offset.column() + thread_coord.contiguous() - + c * ThreadMap::Delta::kContiguous; - - // The subseqnet fast_divmod() operations are equivalent to the following logical computation: - // - // - // filter_r_[c] = rsc_offset / (problem_size_.S * problem_size_.C); - // int residual = rsc_offset % (problem_size_.S * problem_size_.C); - // - // filter_s_[c] = residual / problem_size_.C; - // filter_c_[c] = residual % problem_size_.C; - - int residual; - params_.sc_divmod(precomputed_filter_r_[c], residual, rsc_offset); - params_.c_divmod(precomputed_filter_s_[c], filter_c_[c], residual); - - int r = precomputed_filter_r_[c]; - int s = precomputed_filter_s_[c]; - - if (problem_size_.mode == Mode::kConvolution) { - r = (problem_size_.R - 1 - r); - s = (problem_size_.S - 1 - s); - } - - precomputed_filter_r_[c] = -problem_size_.pad_h + r * problem_size_.dilation_h; - precomputed_filter_s_[c] = -problem_size_.pad_w + s * problem_size_.dilation_w; - } - - // initialize n, p, q offset for every strided iteration - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - offset_npq_[s] = threadblock_offset.row() + thread_coord.strided() - + s * ThreadMap::Delta::kStrided; - } - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - - // moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices; - } - } - - /// Returns the coordinate in the activation tensor x that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - int r = precomputed_filter_r_[iteration_contiguous_]; - int s = precomputed_filter_s_[iteration_contiguous_]; - int c = filter_c_[iteration_contiguous_]; - - if (kAccessesPerVector > 1) { - // This code section is only to support non-128b alignment - // Multiple access to support non-128b alignment in contiguous dimension - int wrap_c; - params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements); - - if (problem_size_.mode == Mode::kConvolution) { - s -= (problem_size_.dilation_w * wrap_c); - - int wrap_s; - params_.s_divmod(wrap_s, s, params_.small_channel_conv_s_offset - s); - s = params_.small_channel_conv_s_offset - s; - - r -= (problem_size_.dilation_h * wrap_s); - - } else { - s += (problem_size_.dilation_w * wrap_c); - - int wrap_s; - params_.s_divmod(wrap_s, s, s + problem_size_.pad_w); - s -= problem_size_.pad_w; - - r += (problem_size_.dilation_h * wrap_s); - } - } - - // The subseqnet fast_divmod() operations are equivalent to the following logical computation: - // - // - // int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q); - // int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q); - // - // int p = residual / problem_size_.Q; - // int q = residual % problem_size_.Q; - - int residual, n, p, q; - - params_.pq_divmod(n, residual, offset_npq_[iteration_strided_]); - params_.q_divmod(p, q, residual); - - int h = p * problem_size_.stride_h + r; - int w = q * problem_size_.stride_w + s; - - return TensorCoord(n, h, w, c); - } - - /// Returns true if the current coordinate is within the activation tensor x - CUTLASS_HOST_DEVICE - bool valid() const { - TensorCoord coord = at(); - - return coord.n() < problem_size_.N && - coord.h() >= 0 && coord.h() < problem_size_.H && - coord.w() >= 0 && coord.w() < problem_size_.W; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dWgradActivationTileAccessIteratorOptimized &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h deleted file mode 100644 index ea48bc6de0f94015b24632b6609ee8c81dac93cb..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ /dev/null @@ -1,260 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dWgradOutputGradientTileAccessIteratorAnalytic { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static_assert(sizeof_bits::value >= 8, - "WGRAD requires elements of size 8b or greater."); - - // - // Parameters structure - // - - using Params = Conv2dAnalyticParams; - -private: - - Params const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - int filter_k_[ThreadMap::Iterations::kContiguous]; - - int offset_npq_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv2dWgradOutputGradientTileAccessIteratorAnalytic( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - // initialize filter_k for every contiguous iteration - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous() - + c * ThreadMap::Delta::kContiguous; - } - - // initialize n, p, q offset for every strided iteration - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_npq_[s] = threadblock_offset.column() + thread_coord.strided() - + s * ThreadMap::Delta::kStrided; - - } - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, layout); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_npq_[s] += Shape::kColumn * problem_size_.split_k_slices; - } - } - - /// Returns the coordinate in the output gradient tensor Dy that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int npq = offset_npq_[iteration_strided_]; - - int n = npq / (problem_size_.P * problem_size_.Q); - int residual = npq % (problem_size_.P * problem_size_.Q); - - int p = residual / problem_size_.Q; - int q = residual % problem_size_.Q; - - int k = filter_k_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; - - return TensorCoord(n, p, q, k); - } - - - /// Returns true if the current coordinate is within the output gradient tensor Dy - CUTLASS_HOST_DEVICE - bool valid() const { - TensorCoord coord = at(); - - return coord.n() < problem_size_.N && - coord.h() < problem_size_.P && - coord.w() < problem_size_.Q && - coord.c() < problem_size_.K; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dWgradOutputGradientTileAccessIteratorAnalytic &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h deleted file mode 100644 index 8e5048fd304f8edd96cca25bc8735725d6f2e843..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ /dev/null @@ -1,310 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray -> -class Conv2dWgradOutputGradientTileAccessIteratorOptimized { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNHWC; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static_assert(sizeof_bits::value >= 8, - "WGRAD requires elements of size 8b or greater."); - - // - // Parameters structure - // - - using Params = Conv2dWgradOutputGradientIteratorOptimizedParams; - -private: - - Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms_; - Conv2dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - uint32_t predicates_[kAccessesPerVector]; - int filter_k_; - int offset_npq_; - -public: - - CUTLASS_HOST_DEVICE - Conv2dWgradOutputGradientTileAccessIteratorOptimized( - Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - predicates_{0}, - filter_k_(0), - offset_npq_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.row() + thread_coord.contiguous(); - offset_npq_ = threadblock_offset.column() + thread_coord.strided(); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous; - int offset_npq = offset_npq_ + s * ThreadMap::Delta::kStrided; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - bool predicate = valid_(at_(offset_npq, filter_k + v * AccessType::kElements)); - - uint32_t pred = (predicate ? 1u : 0); - - int pred_idx = c + s * ThreadMap::Iterations::kContiguous; - - predicates_[v] |= (pred << pred_idx); - } - } - } - - // Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0) - pointer_ += ( - offset_npq_ * params.layout.stride()[0] + filter_k_ - ) * sizeof_bits::value / 8; - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile - offset_npq_ += Shape::kColumn * problem_size_.split_k_slices; - - // Clear predicates if needed - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - if (offset_npq_ + s * ThreadMap::Delta::kStrided >= params_.NPQ) { - uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - predicates_[v] = (predicates_[v] & (~kClearMask)); - } - } - } - - pointer_ += params_.inc_next_npq; - } - -private: - /// Returns the coordinate in the output gradient tensor Dy that is pointed to - /// by offset_npq and k. - CUTLASS_HOST_DEVICE - TensorCoord at_(int offset_npq, int k) const { - - // The subsequent fast_divmod() operations are equivalent to the following logical computation: - // - // - // int npq = offset_npq; - // int n = npq / (problem_size_.P * problem_size_.Q); - // int residual = npq % (problem_size_.P * problem_size_.Q); - // - // int p = residual / problem_size_.Q; - // int q = residual % problem_size_.Q; - - int residual, n, p, q; - - params_.pq_divmod(n, residual, offset_npq); - params_.q_divmod(p, q, residual); - - return TensorCoord(n, p, q, k); - } - - /// Returns true if the coord is within the output gradient tensor Dy - CUTLASS_HOST_DEVICE - bool valid_(TensorCoord coord) const { - - return coord.n() < problem_size_.N && - coord.c() < problem_size_.K; - } - -public: - - /// Returns true if the current coordinate is within the output gradient tensor Dy - CUTLASS_HOST_DEVICE - bool valid() const { - - LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; - return (predicates_[iteration_vector_] & (1u << pred_idx)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - return reinterpret_cast( - pointer_ + - iteration_strided_ * params_.offset_next_strided + - iteration_contiguous_ * params_.offset_next_contiguous - ) + iteration_vector_; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv2dWgradOutputGradientTileAccessIteratorOptimized &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h deleted file mode 100644 index d996003f42587ef6de2af268e1903808991b34d9..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h +++ /dev/null @@ -1,268 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) - matrix from memory. - - This iterator assumes TensorNDHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_ -> -class Conv3dDgradFilterTileAccessIteratorAnalytic { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNDHWC; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - static int const kAccessesPerVector = 1; - - static_assert(sizeof_bits::value >= 8, - "DGRAD requires elements of size 8b or larger."); - - // - // Parameters structure - // - - struct Params { - - Layout layout; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params( - Conv3dProblemSize const &problem_size, - Layout const &layout - ): layout(layout) { - - } - }; - -private: - - Params const ¶ms_; - Conv3dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - char const *pointer_; - - // For a fixed filter position (t,r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension - int filter_t_; - int filter_r_; - int filter_s_; - int offset_k_[ThreadMap::Iterations::kStrided]; - int offset_c_[ThreadMap::Iterations::kContiguous]; - -public: - - CUTLASS_HOST_DEVICE - Conv3dDgradFilterTileAccessIteratorAnalytic( - Params const ¶ms, - Conv3dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - filter_t_(0), - filter_r_(0), - filter_s_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() - + c * ThreadMap::Delta::kContiguous; - } - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_k_[s] = - threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - } - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next tile - ++filter_s_; - if (filter_s_ < problem_size_.S) { - return; - } - filter_s_ = 0; - ++filter_r_; - if (filter_r_ < problem_size_.R) { - return; - } - filter_r_ = 0; - ++filter_t_; - if (filter_t_ < problem_size_.T) { - return; - } - filter_t_ = 0; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; - } - } - - /// Returns the coordinate in the filter tensor w that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int c = offset_c_[iteration_contiguous_]; - int k = offset_k_[iteration_strided_]; - - return TensorCoord(k, filter_t_, filter_r_, filter_s_, c); - } - - /// Returns true if the current coordinate is within the filter tensor w - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - return coord.n() < problem_size_.K && coord.c() < problem_size_.C; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dDgradFilterTileAccessIteratorAnalytic &operator++() { - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv3dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h deleted file mode 100644 index a269b18b0010329dedd31c4689bff8db4fb46d2a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h +++ /dev/null @@ -1,289 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" - -#include "cutlass/conv/threadblock/conv3d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity -> -class Conv3dDgradFilterTileAccessIteratorOptimized { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNDHWC; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = StrideSupport_; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - static int const kAccessesPerVector = 1; - - // - // Parameters structure - // - - struct Params : Conv3dDgradFilterIteratorOptimizedParams { - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params(Conv3dDgradFilterIteratorOptimizedParams const &base): - Conv3dDgradFilterIteratorOptimizedParams(base) { } - - CUTLASS_HOST_DEVICE - Params( - Conv3dProblemSize const &problem_size, - Layout const &layout - ): - Conv3dDgradFilterIteratorOptimizedParams( - problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} - ) { } - - }; - -private: - - Conv3dDgradFilterIteratorOptimizedParams const ¶ms_; - Conv3dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - char const *pointer_; - - uint32_t predicates_; - int filter_trs_; - int filter_k_; - - // - // Assertions - // - - // We map predicates into bits packed in this uint32_t container - static_assert(ThreadMap::Iterations::kStrided * - ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, - "Currently, the number of loads per iteration is limited by the size of the predicates container."); - -public: - - CUTLASS_HOST_DEVICE - Conv3dDgradFilterTileAccessIteratorOptimized( - Conv3dDgradFilterIteratorOptimizedParams const ¶ms, - Conv3dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - predicates_(0), - filter_trs_(0), - filter_k_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.row() + thread_coord.strided(); - Index column = threadblock_offset.column() + thread_coord.contiguous(); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; - int filter_c = column + c * ThreadMap::Delta::kContiguous; - - uint32_t pred = ((filter_k < problem_size_.K && filter_c < problem_size_.C) ? 1u : 0); - - int pred_idx = c + s * ThreadMap::Iterations::kContiguous; - - predicates_ |= (pred << pred_idx); - } - } - - pointer_ += ( - filter_k_ * params.layout.stride()[3] + column - ) * sizeof_bits::value / 8; - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - - LongIndex next = params_.inc_next_trs; - - // moves to the next tile - ++filter_trs_; - if (filter_trs_ == params_.TRS) { - - filter_trs_ = 0; - next = params_.inc_next_k; - filter_k_ += params_.filter_k_delta; - } - - // Clear predicates if needed - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { - uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); - - predicates_ = (predicates_ & (~kClearMask)); - } - } - - pointer_ += next; - } - - /// Returns true if the current coordinate is within the filter tensor W - CUTLASS_HOST_DEVICE - bool valid() { - LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; - return (predicates_ & (1u << pred_idx)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - return reinterpret_cast(pointer_ + - iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dDgradFilterTileAccessIteratorOptimized &operator++() { - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - - // Move to the next K coordinate within the tile - pointer_ += params_.inc_next_strided; - - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv3dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h deleted file mode 100644 index 700c3d12ddfd53b0acef8b6c11188499ca021f76..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ /dev/null @@ -1,343 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) - matrix from memory. - - This iterator assumes TensorNDHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided -> -class Conv3dDgradOutputGradientTileAccessIteratorAnalytic; -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Conv3dDgradOutputGradientTileAccessIteratorAnalytic strided dgrad needs special handling using -// unscaled coordinations -template < - typename Shape_, - typename Element_, - typename ThreadMap_ -> -class Conv3dDgradOutputGradientTileAccessIteratorAnalytic < - Shape_, - Element_, - ThreadMap_, - conv::StrideSupport::kStrided -> { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNDHWC; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - static int const kAccessesPerVector = 1; - - static_assert(sizeof_bits::value >= 8, - "DGRAD requires elements of size 8b or greater."); - - // - // Simpligying assertions - // - - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - struct Params { - - Layout layout; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params( - ConvProblemSize const &problem_size, - Layout const &layout - ): layout(layout) { - - } - }; - -private: - - Params const ¶ms_; - ConvProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - char const *pointer_; - - int filter_k_; - int filter_t_; - int filter_r_; - int filter_s_; - - int offset_n_[ThreadMap::Iterations::kStrided]; - int offset_d_[ThreadMap::Iterations::kStrided]; - int offset_w_[ThreadMap::Iterations::kStrided]; - int offset_h_[ThreadMap::Iterations::kStrided]; - -private: - - /// Returns the coordinate in the output tensor Dy that is currently pointed to - /// by the iterator but DOES NOT scale by the convolution stride. This is needed - /// to compute predicates in the valid() method. The return value of the public at() - /// method is correctly scaled. - CUTLASS_HOST_DEVICE - TensorCoord unscaled_at_() const { - int n = offset_n_[iteration_strided_]; - int d = offset_d_[iteration_strided_]; - int h = offset_h_[iteration_strided_]; - int w = offset_w_[iteration_strided_]; - - int t = filter_t_; - int r = filter_r_; - int s = filter_s_; - - if (problem_size_.mode == Mode::kConvolution) { - t = (problem_size_.T - 1 - t); - r = (problem_size_.R - 1 - r); - s = (problem_size_.S - 1 - s); - } - - int z = (d + problem_size_.pad_d - t * problem_size_.dilation_d); - int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h); - int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w); - - return TensorCoord(n, z, p, q, filter_k_); - } - -public: - - CUTLASS_HOST_DEVICE - Conv3dDgradOutputGradientTileAccessIteratorAnalytic( - Params const ¶ms, - ConvProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - filter_k_(0), - filter_t_(0), - filter_r_(0), - filter_s_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - int offset_ndhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - - offset_n_[s] = offset_ndhw / (problem_size_.D * problem_size_.H * problem_size_.W); - int residual = offset_ndhw % (problem_size_.D * problem_size_.H * problem_size_.W); - - offset_d_[s] = residual / (problem_size_.H * problem_size_.W); - residual = residual % (problem_size_.H * problem_size_.W); - - offset_h_[s] = residual / problem_size_.W; - offset_w_[s] = residual % problem_size_.W; - } - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, layout); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // move to the next tile - ++filter_s_; - if (filter_s_ < problem_size_.S) { - return; - } - filter_s_ = 0; - ++filter_r_; - if (filter_r_ < problem_size_.R) { - return; - } - filter_r_ = 0; - ++filter_t_; - if (filter_t_ < problem_size_.T) { - return; - } - filter_t_ = 0; - - filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; - } - - /// Returns the coordinate in the output tensor Dy that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - TensorCoord coord = unscaled_at_(); - - return TensorCoord( - coord.n(), - coord.d() / problem_size_.stride_d, - coord.h() / problem_size_.stride_h, - coord.w() / problem_size_.stride_w, - coord.c()); - } - - - /// Returns true if the current coordinate is within the output tensor Dy - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord unscaled_coord = unscaled_at_(); - TensorCoord coord = at(); - - return - !(unscaled_coord.d() % problem_size_.stride_d) && - !(unscaled_coord.h() % problem_size_.stride_h) && - !(unscaled_coord.w() % problem_size_.stride_w) && - coord.n() < problem_size_.N && - coord.d() >= 0 && coord.d() < problem_size_.Z && - coord.h() >= 0 && coord.h() < problem_size_.P && - coord.w() >= 0 && coord.w() < problem_size_.Q && - coord.c() < problem_size_.K; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(ConvProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h deleted file mode 100644 index 69915babcbfcacc1a1830a4f9d70885aca5d40c8..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ /dev/null @@ -1,489 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) - matrix from memory. - - This iterator assumes TensorNDHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/conv/threadblock/conv3d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity -> -class Conv3dDgradOutputGradientTileAccessIteratorOptimized { -public: - - static_assert(StrideSupport_ == conv::StrideSupport::kUnity, - "Only unit-stride dgrad is supported at this time."); - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNDHWC; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - using Coord3D = Coord<3>; - static int const kAccessesPerVector = 1; - using Mask = uint64_t; - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv3dDgradOutputGradientIteratorOptimizedParams; - -private: - - Params const ¶ms_; - ConvProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - - - // One pointer per access - char const *pointer_[ThreadMap::Iterations::kStrided]; - - // current filter position (t, r, s) - int filter_t_; - int filter_r_; - int filter_s_; - int filter_k_; - - Index masks_[ThreadMap::Iterations::kStrided][3]; - -public: - - CUTLASS_HOST_DEVICE - Conv3dDgradOutputGradientTileAccessIteratorOptimized( - Params const ¶ms, - ConvProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles - ): - params_(params), - problem_size_(problem_size), - filter_k_(0), - filter_t_(0), - filter_r_(0), - filter_s_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); - - int offset_n[ThreadMap::Iterations::kStrided]; - int offset_d[ThreadMap::Iterations::kStrided]; - int offset_h[ThreadMap::Iterations::kStrided]; - int offset_w[ThreadMap::Iterations::kStrided]; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - pointer_[s] = reinterpret_cast(ptr); - - int offset_ndhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - - // The subseqnet fast_divmod() operations are equivalent to the following logical computation: - // - // - // offset_n[s] = offset_ndhw / (problem_size_.D * problem_size_.H * problem_size_.W); - // int residual = offset_ndhw % (problem_size_.D * problem_size_.H * problem_size_.W); - // - // - // offset_d[s] = residual / (problem_size_.H * problem_size_.W); - // residual = residual % (problem_size_.H * problem_size_.W); - // - // offset_h[s] = residual / problem_size_.W; - // offset_w[s] = residual % problem_size_.W; - // - - int residual; - - // input: (ndhw offset) output: (n offset and resudial (dhw offset)) - params_.dhw_divmod(offset_n[s], residual, offset_ndhw); - // input: (dhw offset) output: (d offset and resudial (hw)) - params_.hw_divmod(offset_d[s], residual, residual); - // input: (hw offset) output: (h offset and resudial (w offset)) - params_.w_divmod(offset_h[s], offset_w[s], residual); - - TensorCoord coord = at_(offset_n[s], offset_d[s], offset_h[s], offset_w[s], 0, 0, 0); - - pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; - } - - clear_mask(); - - CUTLASS_PRAGMA_NO_UNROLL - for (int t = 0; t < problem_size_.T; ++t) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int t_ = t; - if (problem_size_.mode == Mode::kConvolution) { - t_ = problem_size_.T - 1 - t; - } - - int z = offset_d[s_idx] + problem_size_.pad_d - t_ * problem_size_.dilation_d; - - bool pred = (offset_n[s_idx] < problem_size_.N && z >= 0 && z < problem_size_.Z); - masks_[s_idx][0] |= (pred << t); - } - } - - CUTLASS_PRAGMA_NO_UNROLL - for (int r = 0; r < problem_size_.R; ++r) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int r_ = r; - if (problem_size_.mode == Mode::kConvolution) { - r_ = problem_size_.R - 1 - r; - } - - int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h; - - bool pred = (p >= 0 && p < problem_size_.P); - masks_[s_idx][1] |= (pred << r); - } - } - - CUTLASS_PRAGMA_NO_UNROLL - for (int s = 0; s < problem_size_.S; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int s_ = s; - if (problem_size_.mode == Mode::kConvolution) { - s_ = problem_size_.S - 1 - s; - } - - int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w; - - bool pred = (q >= 0 && q < problem_size_.Q); - masks_[s_idx][2] |= (pred << s); - } - } - - if (filter_k_ >= problem_size.K) { - clear_mask(); - } - - set_iteration_index(0); - - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); - } - -private: - - - /// Returns the coordinate in the output gradient tensor dy that is correspoinding to - // activation ndhw and filter position k, t, r, s - CUTLASS_HOST_DEVICE - TensorCoord at_(int n, int d, int h, int w, int t, int r, int s) const { - - if (problem_size_.mode == Mode::kConvolution) { - t = problem_size_.T - 1 - t; - r = problem_size_.R - 1 - r; - s = problem_size_.S - 1 - s; - } - - int z = d + problem_size_.pad_d - t * problem_size_.dilation_d; - int p = h + problem_size_.pad_h - r * problem_size_.dilation_h; - int q = w + problem_size_.pad_w - s * problem_size_.dilation_w; - - return TensorCoord(n, z, p, q, filter_k_); - } - - - /// Adds a pointer offset in units of element - CUTLASS_HOST_DEVICE - void add_byte_offset_(LongIndex byte_offset) { - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - pointer_[s] += byte_offset; - } - } - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask_(bool clear) { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - // We are using inline PTX assembly here to avoid an CUDA C++ compilation - // artifact in which control flow instructions are generated. Instead, our - // intent is to predicate the mov instructions. - #if defined(__CUDA_ARCH__) - asm volatile( - "{\n" - " .reg .pred p;\n" - " .reg .u32 m;" - " mov.u32 m, %2;" - " setp.ne.b32 p, %1, 0;\n" - " @p mov.u32 m, 0;\n" - " mov.u32 %0, m;\n" - "}\n" - : - "=r"(masks_[s][0]) - : - "r"((int)clear), - "r"(masks_[s][0]) - ); - asm volatile( - "{\n" - " .reg .pred p;\n" - " .reg .u32 m;" - " mov.u32 m, %2;" - " setp.ne.b32 p, %1, 0;\n" - " @p mov.u32 m, 0;\n" - " mov.u32 %0, m;\n" - "}\n" - : - "=r"(masks_[s][1]) - : - "r"((int)clear), - "r"(masks_[s][1]) - ); - asm volatile( - "{\n" - " .reg .pred p;\n" - " .reg .u32 m;" - " mov.u32 m, %2;" - " setp.ne.b32 p, %1, 0;\n" - " @p mov.u32 m, 0;\n" - " mov.u32 %0, m;\n" - "}\n" - : - "=r"(masks_[s][2]) - : - "r"((int)clear), - "r"(masks_[s][2]) - ); - #else - if (clear) { - masks_[s][0] = 0; - masks_[s][1] = 0; - masks_[s][2] = 0; - } - #endif - } - } - -public: - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - add_byte_offset_(pointer_offset * sizeof_bits::value / 8); - } - - - CUTLASS_HOST_DEVICE - void advance() { - - int next_idx = 0; - - // moves to the next tile - ++filter_s_; - if (filter_s_ == problem_size_.S) { - - filter_s_ = 0; - ++filter_r_; - next_idx = 1; - - if (filter_r_ == problem_size_.R) { - filter_r_ = 0; - ++filter_t_; - - if (filter_t_ < problem_size_.T) { - next_idx = 2; - } - else { - filter_t_ = 0; - next_idx = 3; - } - } - } - - add_byte_offset_(params_.inc_next[next_idx]); - - if (next_idx == 3) { - filter_k_ += params_.filter_k_delta; - } - - clear_mask_(filter_k_ >= problem_size_.K); - } - - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask() { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - masks_[s][0] = Mask(0); - masks_[s][1] = Mask(0); - masks_[s][2] = Mask(0); - } - } - - CUTLASS_HOST_DEVICE - bool valid() { - - return - (masks_[iteration_strided_][0] & (Index(1) << filter_t_)) && - (masks_[iteration_strided_][1] & (Index(1) << filter_r_)) && - (masks_[iteration_strided_][2] & (Index(1) << filter_s_)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - return reinterpret_cast(pointer_[iteration_strided_]); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dDgradOutputGradientTileAccessIteratorOptimized &operator++() { - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(ConvProblemSize const &problem_size) { - - // This is specialized for unit stride - if (problem_size.stride() != Coord3D({1, 1, 1})) { - return Status::kErrorNotSupported; - } - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { - return Status::kErrorNotSupported; - } - - // Limit on filter size - if (problem_size.T > 32 || problem_size.R > 32 || problem_size.S > 32) { - return Status::kErrorNotSupported; - } - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h deleted file mode 100644 index 5a888e0fe4e63a255cd8bdb6b27de831691f71c8..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h +++ /dev/null @@ -1,291 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) - matrix from memory. - - This iterator assumes TensorNDHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/conv/threadblock/conv3d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_ -> -class Conv3dFpropActivationTileAccessIteratorAnalytic { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNDHWC; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - static int const kAccessesPerVector = 1; - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv3dAnalyticParams; - -private: - - Params const ¶ms_; - ConvProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - char const *pointer_; - - int filter_t_; - int filter_r_; - int filter_s_; - int filter_c_; - - int offset_n_[ThreadMap::Iterations::kStrided]; - int offset_z_[ThreadMap::Iterations::kStrided]; - int offset_p_[ThreadMap::Iterations::kStrided]; - int offset_q_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv3dFpropActivationTileAccessIteratorAnalytic( - Params const ¶ms, - ConvProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - filter_t_(0), - filter_r_(0), - filter_s_(0), - filter_c_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - int offset_nzpq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - - offset_n_[s] = offset_nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); - int residual = offset_nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); - - offset_z_[s] = residual / (problem_size_.P * problem_size_.Q); - residual = residual % (problem_size_.P * problem_size_.Q); - - offset_p_[s] = residual / problem_size_.Q; - offset_q_[s] = residual % problem_size_.Q; - } - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, layout); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next tile - ++filter_s_; - if (filter_s_ < problem_size_.S) { - return; - } - filter_s_ = 0; - ++filter_r_; - if (filter_r_ < problem_size_.R) { - return; - } - filter_r_ = 0; - ++filter_t_; - if (filter_t_ < problem_size_.T) { - return; - } - filter_t_ = 0; - - filter_c_ += Shape::kColumn * problem_size_.split_k_slices; - } - - /// Returns the coordinate in the activations tensor X that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - int n = offset_n_[iteration_strided_]; - int z = offset_z_[iteration_strided_]; - int p = offset_p_[iteration_strided_]; - int q = offset_q_[iteration_strided_]; - - int t = filter_t_; - int r = filter_r_; - int s = filter_s_; - - if (problem_size_.mode == Mode::kConvolution) { - t = (problem_size_.T - 1 - filter_t_); - r = (problem_size_.R - 1 - filter_r_); - s = (problem_size_.S - 1 - filter_s_); - } - - int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; - int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; - int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; - - return TensorCoord(n, d, h, w, filter_c_); - } - - /// Returns true if the current coordinate is within the activations tensor X - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - return coord.n() < problem_size_.N && - coord.d() >= 0 && coord.d() < problem_size_.D && - coord.h() >= 0 && coord.h() < problem_size_.H && - coord.w() >= 0 && coord.w() < problem_size_.W && - coord.c() < problem_size_.C; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - - return ptr; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dFpropActivationTileAccessIteratorAnalytic &operator++() { - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(ConvProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h deleted file mode 100644 index 057023c09cb73199bda94f62c3c879269d7b5189..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h +++ /dev/null @@ -1,478 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) - matrix from memory. - - This iterator assumes TensorNDHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/conv/threadblock/conv3d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename Layout_, - typename ThreadMap_ -> -class Conv3dFpropActivationTileAccessIteratorOptimized { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - static int const kAccessesPerVector = 1; - using Mask = uint64_t; - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv3dFpropActivationIteratorOptimizedParams; - -private: - - Conv3dFpropActivationIteratorOptimizedParams const ¶ms_; - Conv3dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - - // One pointer per access - char const *pointer_[ThreadMap::Iterations::kStrided]; - - // current filter position (t, r, s) - int filter_t_; - int filter_r_; - int filter_s_; - int filter_c_; - - // mask for t, r, and s - Index masks_[ThreadMap::Iterations::kStrided][3]; - -public: - - CUTLASS_HOST_DEVICE - Conv3dFpropActivationTileAccessIteratorOptimized( - Conv3dFpropActivationIteratorOptimizedParams const ¶ms, - Conv3dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles - ) : - params_(params), - problem_size_(problem_size), - filter_t_(0), - filter_r_(0), - filter_s_(0), - filter_c_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); - - int offset_n[ThreadMap::Iterations::kStrided]; - int offset_z[ThreadMap::Iterations::kStrided]; - int offset_p[ThreadMap::Iterations::kStrided]; - int offset_q[ThreadMap::Iterations::kStrided]; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - pointer_[s] = reinterpret_cast(ptr); - - int offset_nzpq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - - // The subseqnet fast_divmod() operations are equivalent to the following logical computation: - // - // - // offset_n[s] = offset_nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); - // int residual = offset_nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); - // - // offset_z[s] = residual / (problem_size_.P * problem_size_.Q); - // residual = residual % (problem_size_.P * problem_size_.Q); - // - // offset_p[s] = residual / problem_size_.Q; - // offset_q[s] = residual % problem_size_.Q; - // - - int residual; - - // input: (nzpq offset) output: (n offset and resudial (zpq offset)) - params.zpq_divmod(offset_n[s], residual, offset_nzpq); - // input: (zpq offset) output: (z offset and resudial (pq)) - params.pq_divmod(offset_z[s], residual, residual); - // input: (pq offset) output: (p offset and resudial (q offset)) - params.q_divmod(offset_p[s], offset_q[s], residual); - - TensorCoord coord = at_(offset_n[s], offset_z[s], offset_p[s], offset_q[s], 0, 0, 0); - - pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; - } - - clear_mask(); - - // mask predicates for filter position T - CUTLASS_PRAGMA_NO_UNROLL - for (int t = 0; t < problem_size_.T; ++t) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int t_ = t; - if (problem_size_.mode == Mode::kConvolution) { - t_ = problem_size_.T - 1 - t; - } - - int d = offset_z[s_idx] * problem_size_.stride_d - problem_size_.pad_d + t_ * problem_size_.dilation_d; - - bool pred = (offset_n[s_idx] < problem_size_.N && d >= 0 && d < problem_size_.D); - masks_[s_idx][0] |= (pred << t); - } - } - - // mask predicates for filter position R - CUTLASS_PRAGMA_NO_UNROLL - for (int r = 0; r < problem_size_.R; ++r) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int r_ = r; - if (problem_size_.mode == Mode::kConvolution) { - r_ = problem_size_.R - 1 - r; - } - - int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; - - bool pred = (h >= 0 && h < problem_size_.H); - masks_[s_idx][1] |= (pred << r); - } - } - - // mask predicates for filter position S - CUTLASS_PRAGMA_NO_UNROLL - for (int s = 0; s < problem_size_.S; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { - - int s_ = s; - if (problem_size_.mode == Mode::kConvolution) { - s_ = problem_size_.S - 1 - s; - } - - int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; - - bool pred = (w >= 0 && w < problem_size_.W); - masks_[s_idx][2] |= (pred << s); - } - } - - if (filter_c_ >= problem_size.C) { - clear_mask(); - } - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); - } - -private: - - /// Returns the coordinate in the activations tensor X that is correspoinding to - // output nzpq and filter position t, r, s - CUTLASS_HOST_DEVICE - TensorCoord at_(int n, int z, int p, int q, int t, int r, int s) const { - - if (problem_size_.mode == Mode::kConvolution) { - t = problem_size_.T - 1 - t; - r = problem_size_.R - 1 - r; - s = problem_size_.S - 1 - s; - } - - int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; - int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; - int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; - - return TensorCoord(n, d, h, w, filter_c_); - } - - /// Adds a pointer offset in units of element - CUTLASS_HOST_DEVICE - void add_byte_offset_(LongIndex byte_offset) { - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - pointer_[s] += byte_offset; - } - } - - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask_(bool clear) { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - // We are using inline PTX assembly here to avoid an CUDA C++ compilation - // artifact in which control flow instructions are generated. Instead, our - // intent is to predicate the mov instructions. - #if defined(__CUDA_ARCH__) - asm volatile( - "{\n" - " .reg .pred p;\n" - " .reg .u32 m;" - " mov.u32 m, %2;" - " setp.ne.b32 p, %1, 0;\n" - " @p mov.u32 m, 0;\n" - " mov.u32 %0, m;\n" - "}\n" - : - "=r"(masks_[s][0]) - : - "r"((int)clear), - "r"(masks_[s][0]) - ); - asm volatile( - "{\n" - " .reg .pred p;\n" - " .reg .u32 m;" - " mov.u32 m, %2;" - " setp.ne.b32 p, %1, 0;\n" - " @p mov.u32 m, 0;\n" - " mov.u32 %0, m;\n" - "}\n" - : - "=r"(masks_[s][1]) - : - "r"((int)clear), - "r"(masks_[s][1]) - ); - asm volatile( - "{\n" - " .reg .pred p;\n" - " .reg .u32 m;" - " mov.u32 m, %2;" - " setp.ne.b32 p, %1, 0;\n" - " @p mov.u32 m, 0;\n" - " mov.u32 %0, m;\n" - "}\n" - : - "=r"(masks_[s][2]) - : - "r"((int)clear), - "r"(masks_[s][2]) - ); - #else - if (clear) { - masks_[s][0] = 0; - masks_[s][1] = 0; - masks_[s][2] = 0; - } - #endif - } - } - -public: - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - add_byte_offset_(pointer_offset * sizeof_bits::value / 8); - } - - CUTLASS_HOST_DEVICE - void advance() { - - int next_idx = 0; - - // moves to the next tile - ++filter_s_; - if (filter_s_ == problem_size_.S) { - - filter_s_ = 0; - ++filter_r_; - next_idx = 1; - - if (filter_r_ == problem_size_.R) { - filter_r_ = 0; - ++filter_t_; - - if (filter_t_ < problem_size_.T) { - next_idx = 2; - } - else { - filter_t_ = 0; - next_idx = 3; - } - } - } - - add_byte_offset_(params_.inc_next[next_idx]); - - if (next_idx == 3) { - filter_c_ += params_.filter_c_delta; - } - - clear_mask_(filter_c_ >= problem_size_.C); - } - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask() { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - masks_[s][0] = Mask(0); - masks_[s][1] = Mask(0); - masks_[s][2] = Mask(0); - } - } - - CUTLASS_HOST_DEVICE - bool valid() { - - return - (masks_[iteration_strided_][0] & (Index(1) << filter_t_)) && - (masks_[iteration_strided_][1] & (Index(1) << filter_r_)) && - (masks_[iteration_strided_][2] & (Index(1) << filter_s_)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - return reinterpret_cast(pointer_[iteration_strided_]); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dFpropActivationTileAccessIteratorOptimized &operator++() { - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv3dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - // Conv3dFpropActivationTileAccessIteratorOptimized has constraint on filter positions - // due to the number of mask bits. - if (problem_size.T > 32 || problem_size.R > 32 || problem_size.S > 32) { - return Status::kErrorNotSupported; - } - return Status::kSuccess; - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h deleted file mode 100644 index 4a40d37e56bfc73744f80dc0cf84e30918b86a1b..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h +++ /dev/null @@ -1,259 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) - matrix from memory. - - This iterator assumes TensorNDHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/conv/threadblock/conv3d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_, - bool IsDeconv_ = false -> -class Conv3dFpropFilterTileAccessIteratorAnalytic { -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNDHWC; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static bool const IsDeconv = IsDeconv_; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - static int const kAccessesPerVector = 1; - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - using Params = Conv3dAnalyticParams; - -private: - - Params const ¶ms_; - ConvProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - char const *pointer_; - - int filter_t_; - int filter_r_; - int filter_s_; - int filter_c_; - - int offset_k_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv3dFpropFilterTileAccessIteratorAnalytic( - Params const ¶ms, - ConvProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - filter_t_(0), - filter_r_(0), - filter_s_(0), - filter_c_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - } - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * 8 / sizeof_bits::value; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next tile - ++filter_s_; - if (filter_s_ < problem_size_.S) { - return; - } - filter_s_ = 0; - - ++filter_r_; - if (filter_r_ < problem_size_.R) { - return; - } - filter_r_ = 0; - - ++filter_t_; - if (filter_t_ < problem_size_.T) { - return; - } - filter_t_ = 0; - - filter_c_ += Shape::kRow * problem_size_.split_k_slices; - } - - /// Returns the coordinate in the filter tensor W that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int k = offset_k_[iteration_strided_]; - - return TensorCoord(k, filter_t_, filter_r_, filter_s_, filter_c_); - } - - /// Returns true if the current coordinate is within the activations tensor W - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); - auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); - - return coord.n() < output_channels && - coord.c() < input_channels; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dFpropFilterTileAccessIteratorAnalytic &operator++() { - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(ConvProblemSize const &problem_size) { - auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); - auto output_channels = (IsDeconv ? problem_size.C : problem_size.K); - // check alignment constraint on iterator's contiguous dimension - if (input_channels % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h deleted file mode 100644 index b4e7db3a4398b67a2a2cf185cf9e689a22d3d0b8..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h +++ /dev/null @@ -1,279 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) - matrix from memory. - - This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" - -#include "cutlass/conv/threadblock/conv3d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename Layout_, - typename ThreadMap_, - bool IsDeconv_ = false -> -class Conv3dFpropFilterTileAccessIteratorOptimized{ -public: - - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static bool const IsDeconv = IsDeconv_; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - static int const kAccessesPerVector = 1; - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - - struct Params : Conv3dFpropFilterIteratorOptimizedParams { - - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params(Conv3dFpropFilterIteratorOptimizedParams const &base): - Conv3dFpropFilterIteratorOptimizedParams(base) { } - - CUTLASS_HOST_DEVICE - Params( - Conv3dProblemSize const &problem_size, - Layout const &layout - ): - Conv3dFpropFilterIteratorOptimizedParams( - problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} - ) { - - } - }; - -private: - - Conv3dFpropFilterIteratorOptimizedParams const ¶ms_; - Conv3dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - char const *pointer_; - - uint32_t predicates_; - int filter_trs_; - int filter_c_; - - // - // Assertions - // - - // We map predicates into bits packed in this uint32_t container - static_assert(ThreadMap::Iterations::kStrided < sizeof(predicates_) * 8, - "Currently, the number of loads per iteration is limited by the size of the predicates container."); - -public: - - CUTLASS_HOST_DEVICE - Conv3dFpropFilterTileAccessIteratorOptimized( - Conv3dFpropFilterIteratorOptimizedParams const ¶ms, - Conv3dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - predicates_{0}, - filter_trs_(0), - filter_c_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); - Index column = threadblock_offset.column() + thread_coord.strided(); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < (IsDeconv ? problem_size_.C : problem_size_.K)) ? 1u : 0); - predicates_ |= (pred << s); - } - - if (filter_c_ >= (IsDeconv ? problem_size_.K : problem_size_.C)) { - predicates_ = 0u; - } - - pointer_ += ( - params_.layout({filter_c_, column}) - ) * sizeof_bits::value / 8; - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - - LongIndex next = params_.inc_next_trs; - - // moves to the next tile - ++filter_trs_; - if (filter_trs_ == params_.TRS) { - - filter_trs_ = 0; - next = params_.inc_next_c; - filter_c_ += params_.filter_c_delta; - } - - if (filter_c_ >= (IsDeconv ? problem_size_.K : problem_size_.C)) { - predicates_ = 0; - } - - pointer_ += next; - } - - /// Returns true if the current coordinate is within the filter tensor W - CUTLASS_HOST_DEVICE - bool valid() { - return (predicates_ & (1u << iteration_strided_)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - return reinterpret_cast(pointer_); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dFpropFilterTileAccessIteratorOptimized &operator++() { - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - - // Move to the next K coordinate within the tile - pointer_ += params_.inc_next_k; - - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv3dProblemSize const &problem_size) { - auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); - - // check alignment constraint on iterator's contiguous dimension - if (input_channels % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_params.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_params.h deleted file mode 100644 index 941f4e1dff7ebfffd6830ad76876087b98a0c8b0..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_params.h +++ /dev/null @@ -1,508 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Extracts the host-params objects into non-template code. -*/ - -#pragma once - -#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/threadblock/conv2d_params.h" -#include "cutlass/conv/conv3d_problem_size.h" - -#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED -#include -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Params structure used for all Conv3d analytic tile iterators -template< typename Layout_ = layout::TensorNDHWC > -struct Conv3dAnalyticParams { - - using Layout = Layout_; - - Layout layout; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv3dAnalyticParams() { } - - CUTLASS_HOST_DEVICE - Conv3dAnalyticParams( - Conv3dProblemSize const &, // unused; placeholder to match other Params interfaces. - Layout const &layout - ): layout(layout) { - - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters structure used for Conv3dFpropActivationTileIteratorOptimized -template< typename Layout_ = layout::TensorNDHWC > -struct Conv3dFpropActivationIteratorOptimizedParams; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters structure used for Conv3dFpropActivationTileIteratorOptimized -template<> -struct Conv3dFpropActivationIteratorOptimizedParams { - - using Layout = layout::TensorNDHWC; - - Layout layout; - - int64_t inc_next[4]; // {next S, next R, next T, next C} - int filter_c_delta; // number of logical elements to add to filter_c_ - int ZPQ; // product of Z*P*Q - int PQ; // product of P*Q - - FastDivmod zpq_divmod; - FastDivmod pq_divmod; - FastDivmod q_divmod; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv3dFpropActivationIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv3dFpropActivationIteratorOptimizedParams( - Conv3dProblemSize const &problem_size, - Layout const &layout, ///< layout object - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout), - PQ(problem_size.P * problem_size.Q), - ZPQ(problem_size.Z * problem_size.P * problem_size.Q), - zpq_divmod(ZPQ), - pq_divmod(PQ), - q_divmod(problem_size.Q) { - - TRACE_CONV_INITIALIZERS("conv3d_fprop", "activation", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - - int conv_sign = (problem_size.mode == Mode::kConvolution ? -1 : 1); - - // next S - inc_next[0] = conv_sign * ( - int64_t(layout.stride()[0]) * problem_size.dilation_w - ) * element_size_bits / 8; - - // next R - inc_next[1] = conv_sign * ( - int64_t(layout.stride()[1]) * problem_size.dilation_h - - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // next T - inc_next[2] = conv_sign * ( - int64_t(layout.stride()[2]) * problem_size.dilation_d - - (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h - - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // next C - inc_next[3] = ( - threadblock_shape.column() * problem_size.split_k_slices - - conv_sign * int64_t(problem_size.T - 1) * layout.stride()[2] * problem_size.dilation_d - - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h - - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // logical offset added to internal channel counter - units are elements, not bytes - filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - - -template< typename Layout_ = layout::TensorNDHWC > -struct Conv3dFpropFilterIteratorOptimizedParams; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Conv3dFpropFilterIteratorOptimizedParams -{ - - using Layout = layout::TensorNDHWC; - - Layout layout; - int TRS; - int filter_c_delta; - - int64_t inc_next_k; // offset in units of bytes to next K position - int64_t inc_next_trs; // offset in units of bytes to next TRS position - int64_t inc_next_c; // offset in units of bytes to next C position - - // - // Methods - // - CUTLASS_HOST_DEVICE - Conv3dFpropFilterIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv3dFpropFilterIteratorOptimizedParams( - Conv3dProblemSize const &problem_size, - Layout const &layout, - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout) { - - TRACE_CONV_INITIALIZERS("conv3d_fprop", "filter", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - TRS = problem_size.T * problem_size.R * problem_size.S; - - inc_next_k = (int64_t(layout.stride()[3]) * threadmap_delta.strided() * element_size_bits) / 8; - - inc_next_trs = - ( int64_t(layout.stride()[0]) - - int64_t(layout.stride()[3]) * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() - ) * element_size_bits / 8; - - inc_next_c = - ( - threadblock_shape.row() * problem_size.split_k_slices - - int64_t(TRS - 1) * layout.stride()[0] - - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3] - ) * element_size_bits / 8; - - filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters object for Conv3d DGRAD OutputGradient (dy) iterator -struct Conv3dDgradOutputGradientIteratorOptimizedParams { - - using Layout = layout::TensorNDHWC; - - Layout layout; - - int64_t inc_next[4]; // {next S, next R, next T, next K} - int filter_k_delta; // number of logical elements to add to filter_k_ - - FastDivmod dhw_divmod; - FastDivmod hw_divmod; - FastDivmod w_divmod; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv3dDgradOutputGradientIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv3dDgradOutputGradientIteratorOptimizedParams( - Conv3dProblemSize const &problem_size, - Layout const &layout, ///< layout object - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout), - dhw_divmod(problem_size.D * problem_size.H * problem_size.W), - hw_divmod(problem_size.H * problem_size.W), - w_divmod(problem_size.W) { - - TRACE_CONV_INITIALIZERS("conv3d_dgrad", "output_gradient", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - int conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); - - // next S - inc_next[0] = conv_sign * ( - int64_t(layout.stride()[0]) * problem_size.dilation_w - ) * element_size_bits / 8; - - // next R - inc_next[1] = conv_sign * ( - int64_t(layout.stride()[1]) * problem_size.dilation_h - - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // next T - inc_next[2] = conv_sign * ( - int64_t(layout.stride()[2]) * problem_size.dilation_d - - (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h - - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // next K - inc_next[3] = ( - threadblock_shape.column() * problem_size.split_k_slices - - conv_sign * int64_t(problem_size.T - 1) * layout.stride()[2] * problem_size.dilation_d - - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h - - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w - ) * element_size_bits / 8; - - // logical offset added to internal channel counter - units are elements, not bytes - filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters object for Conv2d DGRAD Filter (w) iterator -struct Conv3dDgradFilterIteratorOptimizedParams { - - using Layout = layout::TensorNDHWC; - - Layout layout; - int TRS; - int filter_k_delta; - - int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile - int64_t inc_next_trs; // offset in units of bytes to next TRS position - int64_t inc_next_k; // offset in units of bytes to next K position in subsequent tile - - // - // Methods - // - CUTLASS_HOST_DEVICE - Conv3dDgradFilterIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv3dDgradFilterIteratorOptimizedParams( - Conv3dProblemSize const &problem_size, - Layout const &layout, - int element_size_bits, ///< size of each element in bits - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): - layout(layout), TRS(problem_size.T * problem_size.R * problem_size.S) { - - TRACE_CONV_INITIALIZERS("conv3d_dgrad", "filter", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - inc_next_strided = ((int64_t)layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8; - - inc_next_trs = - ( (int64_t)layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] - ) * element_size_bits / 8; - - inc_next_k = - ( - threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[3] - - (problem_size.T * problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] - ) * element_size_bits / 8; - - filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; - } -}; - -/// Parameters object for Conv3d WGRAD OutputGradient iterator -struct Conv3dWgradOutputGradientIteratorOptimizedParams { - - using Layout = layout::TensorNDHWC; - using LongIndex = typename Layout::LongIndex; - - Layout layout; - - int NZPQ; // precomputd product of N*Z*P*Q for clearing predicates - int ZPQ; // product of Z*P*Q - unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ - unsigned zpq_shr; // in device code. - - int PQ; // product of P*Q - unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ - unsigned pq_shr; // in device code. - - unsigned q_mul; // precomputed quantities for fast computation of div/% by Q - unsigned q_shr; // in device code. - - LongIndex offset_next_strided; // offset in units of bytes to next nzpq coordinate within tile - LongIndex offset_next_contiguous; // offset in units of bytes to next k coordinate within tile - LongIndex inc_next_nzpq; // offset in units of bytes to next nzpq position in subsequent tile - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Conv3dWgradOutputGradientIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv3dWgradOutputGradientIteratorOptimizedParams( - Conv3dProblemSize const &problem_size, - Layout const &layout, - int element_size_bits, - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): layout(layout) { - - TRACE_CONV_INITIALIZERS("conv3d_wgrad", "output_gradient", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - // Incremental offsets in unites of bytes (number of elements) * element_size_bits / 8 - offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) - * element_size_bits / 8; - - offset_next_contiguous = (threadmap_delta.contiguous()) - * element_size_bits / 8; - - inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) - * element_size_bits / 8; - - // Precompute several quantities for fast modulo arithmetic. - NZPQ = problem_size.N * problem_size.Z * problem_size.P * problem_size.Q; - ZPQ = problem_size.Z * problem_size.P * problem_size.Q; - find_divisor(zpq_mul, zpq_shr, ZPQ); - - PQ = problem_size.P * problem_size.Q; - find_divisor(pq_mul, pq_shr, PQ); - - find_divisor(q_mul, q_shr, problem_size.Q); - - } -}; - -/// Parameters object for Conv3d WGRAD Activation Tile Access Iterator -struct Conv3dWgradActivationIteratorOptimizedParams { - - using Layout = layout::TensorNDHWC; - - Layout layout; - - int RSC; // product of R*S*C - unsigned rsc_mul; // precomputed quantities for fast computation of div/% by RSC - unsigned rsc_shr; // in device code. - - int SC; // product of S*C - unsigned sc_mul; // precomputed quantities for fast computation of div/% by SC - unsigned sc_shr; // in device code. - - unsigned c_mul; // precomputed quantities for fast computation of div/% by C - unsigned c_shr; // in device code. - - int ZPQ; // product of Z*P*Q - unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ - unsigned zpq_shr; // in device code. - - int PQ; // product of P*Q - unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ - unsigned pq_shr; // in device code. - - unsigned q_mul; // precomputed quantities for fast computation of div/% by Q - unsigned q_shr; // in device code. - - // - // Methods - // - CUTLASS_HOST_DEVICE - Conv3dWgradActivationIteratorOptimizedParams() { } - - CUTLASS_HOST_DEVICE - Conv3dWgradActivationIteratorOptimizedParams( - Conv3dProblemSize const &problem_size, - Layout const &layout, - int element_size_bits, - MatrixCoord threadblock_shape, - int thread_count, - int access_size, - layout::PitchLinearCoord threadmap_iterations, - layout::PitchLinearCoord threadmap_delta - ): layout(layout) { - - TRACE_CONV_INITIALIZERS("conv3d_wgrad", "activation", - element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - - // Precompute several quantities for fast modulo arithmetic. - RSC = problem_size.R * problem_size.S * problem_size.C; - find_divisor(rsc_mul, rsc_shr, RSC); - - SC = problem_size.S * problem_size.C; - find_divisor(sc_mul, sc_shr, SC); - - find_divisor(c_mul, c_shr, problem_size.C); - - ZPQ = problem_size.Z * problem_size.P * problem_size.Q; - find_divisor(zpq_mul, zpq_shr, ZPQ); - - PQ = problem_size.P * problem_size.Q; - find_divisor(pq_mul, pq_shr, PQ); - - find_divisor(q_mul, q_shr, problem_size.Q); - - } -}; - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h deleted file mode 100644 index 97cad0a131667235fbab4c7dd092c1571ae3ee6c..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h +++ /dev/null @@ -1,289 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) - matrix from memory. - - This iterator assumes TensorNDHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_ -> -class Conv3dWgradActivationTileAccessIteratorAnalytic { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNDHWC; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - - static int const kAccessesPerVector = 1; - - static_assert(sizeof_bits::value >= 8, - "WGRAD requires elements of size 8b or greater."); - - // - // Parameters structure - // - - struct Params { - - Layout layout; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params( - Conv3dProblemSize const &problem_size, - Layout const &layout - ): layout(layout) { - - } - }; - -private: - - Params const ¶ms_; - Conv3dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - char const *pointer_; - - // Filter postion (t,r,s,c) in contiguous dimension stays constant for each gemm_iteration_k - int filter_t_[ThreadMap::Iterations::kContiguous]; - int filter_r_[ThreadMap::Iterations::kContiguous]; - int filter_s_[ThreadMap::Iterations::kContiguous]; - int filter_c_[ThreadMap::Iterations::kContiguous]; - - int offset_nzpq_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv3dWgradActivationTileAccessIteratorAnalytic( - Params const ¶ms, - Conv3dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - // initialize t,r,s,c filter position for every contiguous iteration - CUTLASS_PRAGMA_UNROLL - for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int trsc_offset = threadblock_offset.column() + thread_coord.contiguous() - + c * ThreadMap::Delta::kContiguous; - - filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * problem_size_.C); - int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C); - - filter_r_[c] = residual / (problem_size_.S * problem_size_.C); - residual = residual % (problem_size_.S * problem_size_.C); - - filter_s_[c] = residual / problem_size_.C; - filter_c_[c] = residual % problem_size_.C; - - } - - // initialize n, z, p, q offset for every strided iteration - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided() - + s * ThreadMap::Delta::kStrided; - } - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - - // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices; - } - } - - /// Returns the coordinate in the activation tensor x that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int t = filter_t_[iteration_contiguous_]; - int r = filter_r_[iteration_contiguous_]; - int s = filter_s_[iteration_contiguous_]; - - if (problem_size_.mode == Mode::kConvolution) { - t = (problem_size_.T - 1 - t); - r = (problem_size_.R - 1 - r); - s = (problem_size_.S - 1 - s); - } - - int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q); - int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q); - - int z = residual / (problem_size_.P * problem_size_.Q); - residual = residual % (problem_size_.P * problem_size_.Q); - - int p = residual / problem_size_.Q; - int q = residual % problem_size_.Q; - - int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; - int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; - int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; - - return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]); - } - - /// Returns true if the current coordinate is within the activation tensor x - CUTLASS_HOST_DEVICE - bool valid() const { - TensorCoord coord = at(); - - return coord.n() < problem_size_.N && - coord.d() >= 0 && coord.d() < problem_size_.D && - coord.h() >= 0 && coord.h() < problem_size_.H && - coord.w() >= 0 && coord.w() < problem_size_.W && - coord.c() < problem_size_.C; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dWgradActivationTileAccessIteratorAnalytic &operator++() { - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv3dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } - -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h deleted file mode 100644 index 7e5475f8f738e4f434f72e1f50a2c5762904cc42..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h +++ /dev/null @@ -1,319 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) - matrix from memory. - - This iterator assumes TensorNDHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/conv/threadblock/conv3d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_ -> -class Conv3dWgradActivationTileAccessIteratorOptimized { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNDHWC; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - static int const kAccessesPerVector = 1; - static_assert(sizeof_bits::value >= 8, - "WGRAD requires elements of size 8b or greater."); - - // - // Parameters structure - // - - struct Params : Conv3dWgradActivationIteratorOptimizedParams { - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Conv3dWgradActivationIteratorOptimizedParams const &base) - : Conv3dWgradActivationIteratorOptimizedParams(base) {} - - CUTLASS_HOST_DEVICE - Params(Conv3dProblemSize const &problem_size, Layout const &layout) - : Conv3dWgradActivationIteratorOptimizedParams( - problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} - }; - -private: - - Params const ¶ms_; - Conv3dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - char const *pointer_; - - // Precomputed effective filter postion (t,r,s) in contiguous dimension stays constant for each gemm_iteration_k - // required for nzpq -> ndhw translation - int precomputed_filter_t_[ThreadMap::Iterations::kContiguous]; - int precomputed_filter_r_[ThreadMap::Iterations::kContiguous]; - int precomputed_filter_s_[ThreadMap::Iterations::kContiguous]; - - // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k - int filter_c_[ThreadMap::Iterations::kContiguous]; - - int offset_nzpq_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv3dWgradActivationTileAccessIteratorOptimized( - Params const ¶ms, - Conv3dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - // initialize t,r,s,c filter position for every contiguous iteration - CUTLASS_PRAGMA_UNROLL - for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int trsc_offset = threadblock_offset.column() + thread_coord.contiguous() - + c * ThreadMap::Delta::kContiguous; - - // The subseqnet fast_divmod() operations are equivalent to the following logical computation: - // - // - // filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * problem_size_.C); - // int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C); - // - // filter_r_[c] = residual / (problem_size_.S * problem_size_.C); - // residual = residual % (problem_size_.S * problem_size_.C); - // - // filter_s_[c] = residual / problem_size_.C; - // filter_c_[c] = residual % problem_size_.C; - - int residual; - fast_divmod(precomputed_filter_t_[c], residual, trsc_offset, params_.RSC, params_.rsc_mul, params_.rsc_shr); - fast_divmod(precomputed_filter_r_[c], residual, residual, params_.SC, params_.sc_mul, params_.sc_shr); - fast_divmod(precomputed_filter_s_[c], filter_c_[c], residual, problem_size_.C, params_.c_mul, params_.c_shr); - - int t = precomputed_filter_t_[c]; - int r = precomputed_filter_r_[c]; - int s = precomputed_filter_s_[c]; - - if (problem_size_.mode == Mode::kConvolution) { - t = (problem_size_.T - 1 - t); - r = (problem_size_.R - 1 - r); - s = (problem_size_.S - 1 - s); - } - - // efective t,r,s for every contiguous dimension - precomputed_filter_t_[c] = - problem_size_.pad_d + t * problem_size_.dilation_d; - precomputed_filter_r_[c] = - problem_size_.pad_h + r * problem_size_.dilation_h; - precomputed_filter_s_[c] = - problem_size_.pad_w + s * problem_size_.dilation_w; - - - } - - // initialize n, z, p, q offset for every strided iteration - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided() - + s * ThreadMap::Delta::kStrided; - } - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - - // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices; - } - } - - /// Returns the coordinate in the activation tensor x that is currently pointed to - /// by the iterator. - - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - // The subseqnet fast_divmod() operations are equivalent to the following logical computation: - // - // - // int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q); - // int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q); - // - // int z = residual / (problem_size_.P * problem_size_.Q); - // residual = residual % (problem_size_.P * problem_size_.Q); - // - // int p = residual / problem_size_.Q; - // int q = residual % problem_size_.Q; - - int residual, n, z, p, q; - fast_divmod(n, residual, offset_nzpq_[iteration_strided_], params_.ZPQ, params_.zpq_mul, params_.zpq_shr); - fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr); - fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr); - - int d = z * problem_size_.stride_d + precomputed_filter_t_[iteration_contiguous_]; - int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_]; - int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_]; - - return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]); - } - - /// Returns true if the current coordinate is within the activation tensor x - CUTLASS_HOST_DEVICE - bool valid() const { - TensorCoord coord = at(); - - return coord.n() < problem_size_.N && - coord.d() >= 0 && coord.d() < problem_size_.D && - coord.h() >= 0 && coord.h() < problem_size_.H && - coord.w() >= 0 && coord.w() < problem_size_.W && - coord.c() < problem_size_.C; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dWgradActivationTileAccessIteratorOptimized &operator++() { - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv3dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } - -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h deleted file mode 100644 index cbe49985f5df8b76bfd1e57552e47577c379f229..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ /dev/null @@ -1,267 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) - matrix from memory. - - This iterator assumes TensorNDHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_ -> -class Conv3dWgradOutputGradientTileAccessIteratorAnalytic { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNDHWC; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - static int const kAccessesPerVector = 1; - static_assert(sizeof_bits::value >= 8, - "WGRAD requires elements of size 8b or greater."); - - // - // Parameters structure - // - - struct Params { - - Layout layout; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params( - Conv3dProblemSize const &problem_size, - Layout const &layout - ): layout(layout) { - - } - }; - -private: - - Params const ¶ms_; - Conv3dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - char const *pointer_; - - int filter_k_[ThreadMap::Iterations::kContiguous]; - - int offset_nzpq_[ThreadMap::Iterations::kStrided]; - -public: - - CUTLASS_HOST_DEVICE - Conv3dWgradOutputGradientTileAccessIteratorAnalytic( - Params const ¶ms, - Conv3dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)) { - - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - // initialize filter_k for every contiguous iteration - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous() - + c * ThreadMap::Delta::kContiguous; - } - - // initialize n, p, q offset for every strided iteration - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_nzpq_[s] = threadblock_offset.column() + thread_coord.strided() - + s * ThreadMap::Delta::kStrided; - - } - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, layout); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-A by a CTA-K tile - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_nzpq_[s] += Shape::kColumn * problem_size_.split_k_slices; - } - } - - /// Returns the coordinate in the output gradient tensor Dy that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int nzpq = offset_nzpq_[iteration_strided_]; - - int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); - int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); - - int z = residual / (problem_size_.P * problem_size_.Q); - residual = residual % (problem_size_.P * problem_size_.Q); - - int p = residual / problem_size_.Q; - int q = residual % problem_size_.Q; - - return TensorCoord(n, z, p, q, filter_k_[iteration_contiguous_]); - } - - - /// Returns true if the current coordinate is within the output gradient tensor Dy - CUTLASS_HOST_DEVICE - bool valid() const { - TensorCoord coord = at(); - - return coord.n() < problem_size_.N && - coord.d() < problem_size_.Z && - coord.h() < problem_size_.P && - coord.w() < problem_size_.Q && - coord.c() < problem_size_.K; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dWgradOutputGradientTileAccessIteratorAnalytic &operator++() { - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv3dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } - -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h deleted file mode 100644 index 6c2f2e51e5e69f28d552839f906237b35d4879db..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ /dev/null @@ -1,310 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) - matrix from memory. - - This iterator assumes TensorNDHWC layout of tensors in Global Memory. - - The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), - backward data gradient (Dgrad), and backward weight gradient (Wgrad). -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/conv/threadblock/conv3d_params.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, - typename Element_, - typename ThreadMap_ -> -class Conv3dWgradOutputGradientTileAccessIteratorOptimized { -public: - - // - // Types - // - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorNDHWC; - using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 3; - using ConvProblemSize = typename conv::Conv3dProblemSize; - static int const kAccessesPerVector = 1; - static_assert(sizeof_bits::value >= 8, - "WGRAD requires elements of size 8b or greater."); - - // - // Parameters structure - // - - struct Params : Conv3dWgradOutputGradientIteratorOptimizedParams { - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Conv3dWgradOutputGradientIteratorOptimizedParams const &base) - : Conv3dWgradOutputGradientIteratorOptimizedParams(base) {} - - CUTLASS_HOST_DEVICE - Params(Conv3dProblemSize const &problem_size, Layout const &layout) - : Conv3dWgradOutputGradientIteratorOptimizedParams( - problem_size, - layout, - sizeof_bits::value, - {Shape::kRow, Shape::kColumn}, - ThreadMap::kThreads, - ThreadMap::kElementsPerAccess, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} - }; - -private: - - Params const ¶ms_; - Conv3dProblemSize const &problem_size_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - char const *pointer_; - - uint32_t predicates_; - int filter_k_; - int offset_nzpq_; - -public: - - CUTLASS_HOST_DEVICE - Conv3dWgradOutputGradientTileAccessIteratorOptimized( - Params const ¶ms, - Conv3dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - predicates_(0), - filter_k_(0), - offset_nzpq_(0) { - - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.row() + thread_coord.contiguous(); - offset_nzpq_ = threadblock_offset.column() + thread_coord.strided(); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous; - int offset_nzpq = offset_nzpq_ + s * ThreadMap::Delta::kStrided; - - bool predicate = valid_(at_(offset_nzpq, filter_k)); - - uint32_t pred = (predicate ? 1u : 0); - - int pred_idx = c + s * ThreadMap::Iterations::kContiguous; - - predicates_ |= (pred << pred_idx); - } - } - - // Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0) - pointer_ += ( - offset_nzpq_ * params.layout.stride()[0] + filter_k_ - ) * sizeof_bits::value / 8; - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, layout); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile - offset_nzpq_ += Shape::kColumn * problem_size_.split_k_slices; - - // Clear predicates if needed - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - if (offset_nzpq_ + s * ThreadMap::Delta::kStrided >= params_.NZPQ) { - uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); - predicates_ = (predicates_ & (~kClearMask)); - } - } - pointer_ += params_.inc_next_nzpq; - } - -private: - /// Returns the coordinate in the output gradient tensor Dy that is (offset_nzpq, k) pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at_(int offset_nzpq, int k) const { - - // The subseqnet fast_divmod() operations are equivalent to the following logical computation: - // - // - // int nzpq = offset_nzpq_; - // int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); - // int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); - // - // int z = residual / (problem_size_.P * problem_size_.Q); - // residual = residual % (problem_size_.P * problem_size_.Q); - // - // int p = residual / problem_size_.Q; - // int q = residual % problem_size_.Q; - - int residual, n, z, p, q; - fast_divmod(n, residual, offset_nzpq, params_.ZPQ, params_.zpq_mul, params_.zpq_shr); - fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr); - fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr); - - return TensorCoord(n, z, p, q, k); - } - - /// Returns true if the coord is within the output gradient tensor Dy - CUTLASS_HOST_DEVICE - bool valid_(TensorCoord coord) const { - - return coord.n() < problem_size_.N && - coord.c() < problem_size_.K; - } - -public: - - /// Returns true if the current coordinate is within the output gradient tensor Dy - CUTLASS_HOST_DEVICE - bool valid() const { - - LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; - return (predicates_ & (1u << pred_idx)); - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - - return reinterpret_cast( - pointer_ + - iteration_strided_ * params_.offset_next_strided + - iteration_contiguous_ * params_.offset_next_contiguous - ); - - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - Conv3dWgradOutputGradientTileAccessIteratorOptimized &operator++() { - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv3dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } - -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h deleted file mode 100644 index f5cd2a740232c257f8e3b25c37408973f536722b..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h +++ /dev/null @@ -1,230 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Extracts the host-params objects into non-template code. -*/ - -#pragma once - -#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" - -#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED -#include -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized -template -struct Depthwise2dFpropDirectConvParams; - -/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation -template -struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; - -/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized -template -struct Depthwise2dFpropDirectConvFilterIteratorParams; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized -template<> -struct Depthwise2dFpropDirectConvParams { - - using Layout = layout::TensorNHWC; - - Layout layout; - - int32_t activation_tile_h; - int32_t activation_tile_w; - int32_t activation_tile_hw; - FastDivmod activation_tile_w_divmod; - - int filter[2]; - int stride[2]; - int dilation[2]; - int inc_next[2]; - FastDivmod pq_divmod; - FastDivmod q_divmod; - - int activation_load_count; - int activation_storage_elements; - int activation_size; - // - // Methods - // - - CUTLASS_HOST_DEVICE - Depthwise2dFpropDirectConvParams() { } - - CUTLASS_HOST_DEVICE - Depthwise2dFpropDirectConvParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, ///< layout object - MatrixCoord threadblock_shape, ///< CTA threadblock Shape - Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock - const int element_size_bits, ///< bits of activation element - const int thread_count, ///< threads per threadblock - const int thread_count_contiguous, ///< number of threads for continuous dimension - const int element_per_load) ///< element per each load - : layout(layout) { - - filter[0] = problem_size.S; - filter[1] = problem_size.R; - - stride[0] = problem_size.stride_w; - stride[1] = problem_size.stride_h; - - dilation[0] = problem_size.dilation_w; - dilation[1] = problem_size.dilation_h; - - // Compute activation_tile size per threadblock because stride and dilation are runtime params. - activation_tile_h = (threadblock_output_shape.h() - 1) * problem_size.stride_h + - (problem_size.R - 1) * problem_size.dilation_h + 1; - activation_tile_w = (threadblock_output_shape.w() - 1) * problem_size.stride_w + - (problem_size.S - 1) * problem_size.dilation_w + 1; - activation_tile_hw = activation_tile_h * activation_tile_w; - - activation_tile_w_divmod = FastDivmod(activation_tile_w); - - /// Below two values could not be templatized because the stride and dilation are runtime params - activation_load_count = (thread_count_contiguous * activation_tile_hw + (thread_count - 1)) / thread_count; - activation_storage_elements = activation_load_count * element_per_load * thread_count; - activation_size = activation_storage_elements * element_size_bits / 8; - - // Fastdivmod for output P, Q - int tiles_p = - (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); - int tiles_q = (problem_size.Q + (threadblock_output_shape.w() - 1)) / - (threadblock_output_shape.w()); - - pq_divmod = FastDivmod(tiles_p * tiles_q); - q_divmod = FastDivmod(tiles_q); - - // next S - inc_next[0] = problem_size.dilation_w; - // next R - inc_next[1] = (activation_tile_w * problem_size.dilation_h - (problem_size.S - 1) * problem_size.dilation_w); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation -template <> -struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams { - using Layout = layout::TensorNHWC; - - Layout layout; - - FastDivmod pq_divmod; - FastDivmod q_divmod; - - int activation_size; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams() {} - - CUTLASS_HOST_DEVICE - Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, ///< Layout object - MatrixCoord threadblock_shape, ///< Threadblock Shape - Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock - const int activation_size_ ///< Activation size loaded by iterator - ) - : layout(layout), - activation_size(activation_size_) { - // Fastdivmod for output P, Q - int tiles_p = - (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); - int tiles_q = - (problem_size.Q + (threadblock_output_shape.w() - 1)) / (threadblock_output_shape.w()); - - pq_divmod = FastDivmod(tiles_p * tiles_q); - q_divmod = FastDivmod(tiles_q); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized -template <> -struct Depthwise2dFpropDirectConvFilterIteratorParams { - using Layout = layout::TensorNHWC; - - Layout layout; - - int filter_size; - - bool is_convolution; - // - // Methods - // - - CUTLASS_HOST_DEVICE - Depthwise2dFpropDirectConvFilterIteratorParams() {} - - CUTLASS_HOST_DEVICE - Depthwise2dFpropDirectConvFilterIteratorParams( - Conv2dProblemSize const &problem_size, - Layout const &layout, ///< Layout object - MatrixCoord threadblock_shape, ///< Threadblock Shape - const int filter_size_) ///< Filter size loaded by iterator - : layout(layout), - filter_size(filter_size_), - is_convolution(problem_size.mode == Mode::kConvolution){} -}; - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h deleted file mode 100644 index 012e306d800c3bcd62c1322a217d643d2ae38fd5..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h +++ /dev/null @@ -1,314 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" -#include "cutlass/coord.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template > -class DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation { - public: - // - // Types - // - - using Shape = Shape_; - using OutputTileShape = OutputTileShape_; - using Element = Element_; - using Layout = Layout_; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - // Compilation value of stride , dialtion and activation shape - using StrideShape = StrideShape_; - using DilationShape = DilationShape_; - using ActivationShape = ActivationShape_; - - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - static int const kActivationSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * - sizeof_bits::value / 8; - - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); - - static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); - static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); - - // - // Parameters structure - // - - using Params = Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; - - private: - Conv2dProblemSize const &problem_size_; - Params const ¶ms_; - char const *pointer_; - - // Base channels for current threadblock - int base_c_; - // Base activation index for current threadblock - int offset_intial_npq_; - // Base activation coord for current threadblock - TensorCoord activatioin_base_; - // Intial thread positioin - int offset_initial_hwc_; - // Overall load instruction per thread. - int iterator_load_; - // thread loading position. - int iterator_hwc_; - // activation N is inside the Tensor or not - bool valid_n_; - - public: - - - CUTLASS_HOST_DEVICE - DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = - MatrixCoord() - ) - : params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - offset_intial_npq_(threadblock_offset.row()), - offset_initial_hwc_(thread_idx), - iterator_load_(0) { - - base_c_ = threadblock_offset.column(); - - set_iteration_index(0); - - set_activation_coord(offset_intial_npq_); - - } - - CUTLASS_HOST_DEVICE - void set_activation_coord(int offset_npq) { - int offset_inital_n, offset_inital_p, offset_inital_q; - int residual; - - params_.pq_divmod(offset_inital_n, residual, offset_npq); - params_.q_divmod(offset_inital_p, offset_inital_q, residual); - - int base_n = offset_inital_n; - - int base_h = - offset_inital_p * OutputTileShape::kH * StrideShape::kRow - problem_size_.pad_h; - - int base_w = - offset_inital_q * OutputTileShape::kW * StrideShape::kColumn - problem_size_.pad_w; - - activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); - - valid_n_ = activatioin_base_.n() < problem_size_.N; - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params( - problem_size, - layout, - {Shape::kRow, Shape::kColumn}, - {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, - kActivationSize); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; - iterator_load_ = index; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // Go to next threadblock - offset_intial_npq_ += problem_size_.split_k_slices; - - set_iteration_index(0); - - set_activation_coord(offset_intial_npq_); - } - - /// Returns the coordinate in the activations tensor X that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; - int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; - int h = next / ActivationShape::kW; - int w = next % ActivationShape::kW; - - c = c * AccessType::kElements; - - return activatioin_base_ + TensorCoord(0, h, w, c); - } - - /// Returns true if the current coordinate is within the activations tensor X - CUTLASS_HOST_DEVICE - bool valid() const { - TensorCoord coord = at(); - bool valid_c = coord.c() < problem_size_.C; - bool valid_h = coord.h() >= 0 && coord.h() < problem_size_.H; - bool valid_w = coord.w() >= 0 && coord.w() < problem_size_.W; - return valid_n_ ? valid_c & valid_h & valid_w : 0; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - AccessType const *ptr = - reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - - return ptr; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation &operator++() { - - ++iterator_load_; - iterator_hwc_ += ThreadMap::kThreads; - - if (iterator_load_ < ThreadMap::Iterations::kCount) { - return *this; - } - - iterator_load_ = 0; - iterator_hwc_ = offset_initial_hwc_; - - return *this; - } - - /// Determines the activation size loaded by iterator - CUTLASS_HOST_DEVICE - int get_load_size() { - return kActivationSize; - } - - /// Determines the iterations needed - CUTLASS_HOST_DEVICE - int get_iteration_num() { - return ThreadMap::Iterations::kCount; - } - - /// Determines whether the Depthwise fprop can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check stride and dilation constraint - if (problem_size.stride_h != StrideShape::kRow || problem_size.stride_w != StrideShape::kColumn) { - return Status::kErrorInvalidProblem; - } - - if (problem_size.dilation_h != DilationShape::kRow || problem_size.dilation_w != DilationShape::kColumn) { - return Status::kErrorInvalidProblem; - } - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h deleted file mode 100644 index b8ae9b9312c79f88715fd8cd1efebb2dad8a76f1..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h +++ /dev/null @@ -1,291 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" -#include "cutlass/coord.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template > -class DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized { - public: - // - // Types - // - - using Shape = Shape_; - using OutputTileShape = OutputTileShape_; - using Element = Element_; - using Layout = Layout_; - using TensorCoord = typename Layout::TensorCoord; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); - - static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); - static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); - - // - // Parameters structure - // - - using Params = Depthwise2dFpropDirectConvParams; - - private: - Conv2dProblemSize const &problem_size_; - Params const ¶ms_; - char const *pointer_; - - // Base channels for current threadblock - int base_c_; - // Base activation index for current threadblock - int offset_intial_npq_; - // Base activation coord for current threadblock - TensorCoord activatioin_base_; - // Intial thread positioin - int offset_initial_hwc_; - // Overall load instruction per thread. - int iterator_load_; - // thread loading position. - int iterator_hwc_; - // Number of loads for activations tensor X. - const int number_of_loads_; - - public: - - - CUTLASS_HOST_DEVICE - DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = - MatrixCoord() - ) - : params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - offset_intial_npq_(threadblock_offset.row()), - offset_initial_hwc_(thread_idx), - iterator_load_(0), - number_of_loads_(params.activation_load_count) { - - base_c_ = threadblock_offset.column(); - - set_activation_coord(offset_intial_npq_); - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - void set_activation_coord(int offset_npq) { - int offset_inital_n, offset_inital_p, offset_inital_q; - int residual; - - params_.pq_divmod(offset_inital_n, residual, offset_npq); - params_.q_divmod(offset_inital_p, offset_inital_q, residual); - - int base_n = offset_inital_n; - - int base_h = - offset_inital_p * OutputTileShape::kH * problem_size_.stride_h - problem_size_.pad_h; - - int base_w = - offset_inital_q * OutputTileShape::kW * problem_size_.stride_w - problem_size_.pad_w; - - activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params( - problem_size, - layout, - {Shape::kRow, Shape::kColumn}, - {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, - sizeof_bits::value, - ThreadMap::kThreads, - ThreadMap::Detail::ShapeVec::kContiguous, - ThreadMap::kElementsPerAccess); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; - iterator_load_ = index; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - CUTLASS_HOST_DEVICE - void advance() { - // Go to next threadblock - offset_intial_npq_ += problem_size_.split_k_slices; - - set_activation_coord(offset_intial_npq_); - } - - /// Returns the coordinate in the activations tensor X that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; - int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; - int h, w; - params_.activation_tile_w_divmod(h, w, next) ; - - c = c * AccessType::kElements; - - return activatioin_base_ + TensorCoord(0, h, w, c); - } - - /// Returns true if the current coordinate is within the activations tensor X - CUTLASS_HOST_DEVICE - bool valid() const { - TensorCoord coord = at(); - - return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.H && - coord.w() >= 0 && coord.w() < problem_size_.W && coord.c() < problem_size_.C; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - TensorCoord coord = at(); - LongIndex offset = params_.layout(coord); - - AccessType const *ptr = - reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - - return ptr; - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized &operator++() { - - ++iterator_load_; - iterator_hwc_ += ThreadMap::kThreads; - - if (iterator_load_ < number_of_loads_) { - return *this; - } - - iterator_load_ = 0; - iterator_hwc_ = offset_initial_hwc_; - - return *this; - } - - /// Determines the activation size loaded by iterator - CUTLASS_HOST_DEVICE - int get_load_size() { - return params_.activation_size; - } - - /// Determines the iterations needed - CUTLASS_HOST_DEVICE - int get_iteration_num() { - return number_of_loads_; - } - - /// Determines whether the Depthwise fprop can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h deleted file mode 100644 index 846f1f3aeb269edc67b5e2c02db3f05993172025..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h +++ /dev/null @@ -1,551 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/cache_operation.h" -#include "cutlass/conv/threadblock/depthwise_mma_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Epilogue stores the data into global memory - typename Epilogue_, - /// iterator implementation variants - conv::IteratorAlgorithm IteratorAlgorithm_ = conv::IteratorAlgorithm::kOptimized, - /// Used for partial specialization - typename Enable = bool> -class DepthwiseFpropDirectConvMultipleStage : - public DepthwiseDirectConvMmaBase { -public: - ///< Base class - using Base = DepthwiseDirectConvMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Policy describing tuning details - using Policy = Policy_; - - using Epilogue = Epilogue_; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - static conv::IteratorAlgorithm const kItertorAlgorithm = IteratorAlgorithm_; - - // - // Dependent types - // - - /// Fragment of accumulator tile - - using ElementC = typename Policy::Operator::ElementC; - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Internal structure exposed for introspection. - struct Detail { - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - - private: - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - DepthwiseFpropDirectConvMultipleStage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, - IteratorB &iterator_B, - int group_start_A = 0, - int group_start_B = 0) { - if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { - // Number of iterators is a static value. - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - ++this->smem_iterator_A_; - } - } else { - // Number of iterators is a runtime value. - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - ++this->smem_iterator_A_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA &iterator_A, - ///< Params of global memory iterator - typename IteratorA::Params const &iterator_a_params, - ///< iterator over B operand in global memory - IteratorB &iterator_B, - ///< Params of global memory iterator - typename IteratorB::Params const &iterator_b_params, - ///< initial value of accumulator - FragmentC const &src_accum, - /// Epilogue - Epilogue &epilogue, - ///< Output operator - typename Epilogue::OutputOp const &output_op, - ///< Tile iterator for destination - typename Epilogue::OutputTileIterator &destination_iterator, - ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) - typename Epilogue::OutputTileIterator &source_iterator, - - int split_k_slices = 1 - ) { - - // - // Prologue - // - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - - if (stage == 0) { - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - } - - if(kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation){ - // Number of iterators is compilation static. - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - } else { - // Number of iterators is a runtime value. - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_num(iterator_A.get_iteration_num()); - this->smem_iterator_A_.set_iteration_index(0); - - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - // Move to the next stage - iterator_A.advance(); - - this->smem_iterator_A_.add_tile_offset({1, 0}); - - // Inserts a fence to group cp.async instructions into stages. - cutlass::arch::cp_async_fence(); - } - - ///////////////////////////////////////////////////////////////////////////// - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpLoadedFragmentA warp_loaded_frag_A[2]; - WarpLoadedFragmentB warp_loaded_frag_B[2]; - WarpTransformedFragmentA warp_transformed_frag_A[2]; - WarpTransformedFragmentB warp_transformed_frag_B[2]; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); - - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], - warp_loaded_frag_A[0], warp_loaded_frag_B[0]); - - // - // Mainloop - // - - unsigned int iterations = 0; - constexpr int inner_loop_iterations = round_up(Base::kWarpGemmIterations, 2); - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { // Each iteration is a cta tile. - - accum.clear(); - - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < inner_loop_iterations; ++warp_mma_k) { - if (Base::kWarpGemmIterations % 2 == 0 || warp_mma_k + 1 != Base::kWarpGemmIterations) { - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k > 0) - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B[warp_mma_k % 2]); - - // Issue global->shared copies for the next stage - int group_start_iteration_A, group_start_iteration_B; - - if (warp_mma_k == 0) { - group_start_iteration_A = 0; - group_start_iteration_B = 0; - copy_tiles_and_advance( - iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - } - - if (warp_mma_k < Base::kWarpGemmIterations) { - warp_mma( - accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - accum - ); - } - - if (warp_mma_k + 1 == inner_loop_iterations) - warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - if (warp_mma_k + 2 == inner_loop_iterations) { - // Inserts a fence to group cp.async instructions into stages. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages of cp.async have committed - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next cta - iterator_A.advance(); - - this->smem_iterator_A_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({-Base::kStages, 0}); - - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.advance(- (Base::kStages-1) * iterator_A.get_load_size()); - smem_read_stage_idx = 0; - } else { - this->warp_tile_iterator_A_.advance(iterator_A.get_load_size()); - ++smem_read_stage_idx; - } - - if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { - this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); - } - - // goback to start position. B has no multiple stage - this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Shape::kK, 0}); - - --gemm_k_iterations; - } - } - - // - // Epilogue - // - int32_t smem_base_offset = iterator_B.get_load_size() + (iterations % Base::kStages) * iterator_A.get_load_size(); - - destination_iterator.set_tile_index(iterations * split_k_slices); - - source_iterator.set_tile_index(iterations * split_k_slices); - - epilogue(output_op, destination_iterator, accum, source_iterator, smem_base_offset); - - ++iterations; - } - - // Insert fence and wait for all outstanding cp.async operations to commit. - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h deleted file mode 100644 index 1035fda375787cc211929854b662d1ccb7a809ae..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h +++ /dev/null @@ -1,261 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) - matrix from memory. - - This iterator assumes TensorNHWC layout of tensors in Global Memory. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/threadblock/conv2d_params.h" -#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -template > -class DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized { -public: - // - // Types - // - - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - using TensorRef = cutlass::TensorRef; - using TensorCoord = typename Layout::TensorCoord; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; - static int const kConvDim = 2; - using ConvProblemSize = typename conv::Conv2dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static int const kFilterSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * - sizeof_bits::value / 8; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - // - // Simplifying assertions - // - static_assert(ThreadMap::Iterations::kContiguous == 1, - "Require Iterations::kContiguous == 1"); - - // - // Parameters structure - // - using Params = Depthwise2dFpropDirectConvFilterIteratorParams; - - protected: - - Conv2dProblemSize const &problem_size_; - Params const ¶ms_; - LongIndex iteration_contiguous_; - LongIndex iteration_strided_; - LongIndex iteration_vector_; - char const *pointer_; - - int filter_k_; - int offset_trs_[ThreadMap::Iterations::kStrided]; - -public: - - - - CUTLASS_HOST_DEVICE - DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized( - Params const ¶ms, - Conv2dProblemSize const &problem_size, - Element const *ptr, - int thread_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - params_(params), - problem_size_(problem_size), - pointer_(reinterpret_cast(ptr)), - filter_k_(0) { - - layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); - - filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - offset_trs_[s] = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; - } - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, layout, {Shape::kRow, Shape::kColumn}, kFilterSize); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(Index index) { - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset * 8 / sizeof_bits::value; - } - - CUTLASS_HOST_DEVICE - void advance() { - // Do nothing because the filter is persistent in the SMEM - } - - /// Returns the coordinate in the filter tensor W that is currently pointed to - /// by the iterator. - CUTLASS_HOST_DEVICE - TensorCoord at() const { - - int k = filter_k_ + iteration_vector_ * AccessType::kElements; - int trs = offset_trs_[iteration_strided_]; - - return TensorCoord(k, trs, 0 , 0); // As a 2D-matrix - } - - /// Returns true if the current coordinate is within the activations tensor W - CUTLASS_HOST_DEVICE - bool valid() const { - - TensorCoord coord = at(); - - return coord.n() < problem_size_.K && - coord.h() < Shape::kColumn; - } - - /// Returns a pointer to the vector starting at the current coordinate - CUTLASS_HOST_DEVICE - AccessType const *get() const { - TensorCoord coord = at(); - int64_t offset = coord.n(); - if (params_.is_convolution) { - offset += (Shape::kColumn - coord.h() - 1)* problem_size_.K; - } else { - offset += coord.h() * problem_size_.K; - } - - return reinterpret_cast(pointer_ + - offset * sizeof_bits::value / 8); - } - - /// Increments to the next memory access - CUTLASS_HOST_DEVICE - DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized &operator++() { - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - iteration_vector_ = 0; - - ++iteration_contiguous_; - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - iteration_contiguous_ = 0; - - ++iteration_strided_; - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - iteration_strided_ = 0; - - return *this; - } - - /// Determines the filter size loaded by iterator - CUTLASS_HOST_DEVICE - int get_load_size() { - return kFilterSize; - } - - /// Determines whether the Implicit GEMM can execute the given problem. - CUTLASS_HOST_DEVICE - static Status can_implement(Conv2dProblemSize const &problem_size) { - - // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { - return Status::kErrorInvalidProblem; - } - - // check whether runtime filter size is same as templated filter size. - if ((problem_size.R * problem_size.S) != Shape::kColumn) { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h deleted file mode 100644 index 30d13e9087e4b47384a43b9381036f668d581808..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h +++ /dev/null @@ -1,336 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/aligned_buffer.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/mma_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Transformation applied to A operand - typename TransformA_ = NumericArrayConverter< - typename SmemIteratorA_::Element, - typename IteratorA_::Element, - IteratorA_::Fragment::kElements>, - /// - /// Transformation applied to A operand - typename TransformB_ = NumericArrayConverter< - typename SmemIteratorB_::Element, - typename IteratorB_::Element, - IteratorB_::Fragment::kElements>, - /// Used for partial specialization - typename Enable = bool -> -class DepthwiseFpropPipelined : public gemm::threadblock::MmaBase { -public: - - ///< Base class - using Base = gemm::threadblock::MmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - using TransformA = TransformA_; - using TransformB = TransformB_; - - // - // Dependent types - // - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); - -private: - - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - -protected: - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - DepthwiseFpropPipelined( - typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC &accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - FragmentC const &src_accum, ///< source accumulator tile - int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel - TransformA transform_A = TransformA(), ///< transformation applied to A fragment - TransformB transform_B = TransformB()) { ///< transformation applied to B fragment - - // - // Prologue - // - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A; - FragmentB tb_frag_B; - - tb_frag_A.clear(); - tb_frag_B.clear(); - - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - this->smem_iterator_A_.store(transform_A(tb_frag_A)); - this->smem_iterator_B_.store(transform_B(tb_frag_B)); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - Operator warp_mma; - - int smem_write_stage_idx = 1; - // Depthwise specific - int channel_start_index = 0; - int rs_plane_idx = 0; - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tightest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) { - // - // Loop over GEMM K dimension - // - - if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ - // Reset interation index. - iterator_B.set_iteration_index(0); - } - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transform_A(tb_frag_A)); - - this->smem_iterator_B_.store(transform_B(tb_frag_B)); - - __syncthreads(); - - if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ - // Move to next set of filter groups. - channel_start_index += Base::kWarpGemmIterations; - } - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, - 0}); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k == 0) { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - } - - warp_mma(accum, warp_frag_A[warp_mma_k % 2], - warp_frag_B[warp_mma_k % 2], accum); - } - - rs_plane_idx = (rs_plane_idx == gemm_k_iterations_per_channel - 1) ? 0: (rs_plane_idx + 1); - - } - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h deleted file mode 100644 index 44dafcb5fa4f099e8070a9c8d271c4048128ceac..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h +++ /dev/null @@ -1,229 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a directconv threadblock-scoped Depthwise kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Policy object describing MmaTensorOp -template < - /// Warp-level GEMM operator (concept: gemm::warp::Mma) - typename Operator_, - /// Padding used for A operand in shared memory (concept: MatrixShape) - typename SmemPaddingA_, - /// Padding used for B operand in shared memory (concept: MatrixShape) - typename SmemPaddingB_, - /// - typename ThreadMapA_, - /// - typename ThreadMapB_, - /// Number of partitions of K dimension of GEMM - int PartitionsK = 1> -struct DepthwiseDirectConvMmaPolicy { - /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) - using Operator = Operator_; - - /// Padding used for A operand in shared memory - using SmemPaddingA = SmemPaddingA_; - - /// Padding used for B operand in shared memory - using SmemPaddingB = SmemPaddingB_; - - using ThreadMapA = ThreadMapA_; - using ThreadMapB = ThreadMapB_; - - /// Number of partitions of K dimension - static int const kPartitionsK = PartitionsK; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class DepthwiseDirectConvMmaBase { - public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Policy describing tuning details - using Policy = Policy_; - - // - // Dependent types - // - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = cutlass::gemm:: - GemmShape; - - /// Number of warp-level GEMM oeprations - /// kWarpGemmIterations could be even and odd. - static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - static_assert(kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = MatrixShape<1, // Not determined at compile-time :( - Shape::kN + Policy::SmemPaddingA::kRow>; - - /// Shape of the B matrix operand in shared memory - using ShapeB = MatrixShape; // Tile N = 64? - - public: - // - // Data members - // - - // Let persistent B matrix in front of dynamic matrix A - /// Buffer for B operand - AlignedBuffer operand_B; - - /// Buffer for A operand - /// Not be determined at compile-time -- Just to get a Smem start address. - AlignedBuffer operand_A; - public: - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() { return TensorRefA{operand_A.data(), LayoutA()}; } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } - }; - - protected: - // - // Data members - // - - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; - - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; - - public: - /// Construct from tensor references - CUTLASS_DEVICE - DepthwiseDirectConvMmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h deleted file mode 100644 index 9e3cc417d4cc3169724f9e5db9e82fa093121fae..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h +++ /dev/null @@ -1,952 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data - layout of the global memory fragments, data types, and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting depthwise related simt instructions. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/warp/mma_depthwise_simt.h" - -#include "cutlass/gemm/threadblock/mma_pipelined.h" -#include "cutlass/gemm/threadblock/mma_singlestage.h" - -#include "cutlass/gemm/threadblock/mma_base.h" -#include "cutlass/conv/threadblock/depthwise_mma_base.h" - -#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h" - -#include "cutlass/arch/cache_operation.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -namespace detail { -// -// Convert a WarpShapeM which is the whole tile of elements into the number of elements (2D) held by -// each partitions within warp. -// The goal is for each thread's tile of elements to be as square as -// possible for performance (4x4 will be faster than 2x8). -template // The number of partitions within the warp -struct SimtWarpShape { - // kP * kQ * WarpNumThreadsM = WarpShapeM - // If needed, enable more specializations. -}; -template <> -struct SimtWarpShape<4, 4> { - static constexpr int kP = 1; - static constexpr int kQ = 1; -}; - -template <> -struct SimtWarpShape<4, 2> { - static constexpr int kP = 2; - static constexpr int kQ = 1; -}; - -template <> -struct SimtWarpShape<4, 1> { - static constexpr int kP = 2; - static constexpr int kQ = 2; -}; - -template <> -struct SimtWarpShape<8, 1> { - static constexpr int kP = 2; - static constexpr int kQ = 4; -}; -template <> -struct SimtWarpShape<8, 2> { - static constexpr int kP = 2; - static constexpr int kQ = 2; -}; -template <> -struct SimtWarpShape<8, 4> { - static constexpr int kP = 1; - static constexpr int kQ = 2; -}; - -template <> -struct SimtWarpShape<16, 1> { - static constexpr int kP = 4; - static constexpr int kQ = 4; -}; -template <> -struct SimtWarpShape<16, 2> { - static constexpr int kP = 2; - static constexpr int kQ = 4; -}; -template <> -struct SimtWarpShape<16, 4> { - static constexpr int kP = 2; - static constexpr int kQ = 2; -}; - -template -struct SimtWarpShape<25, WarpNumThreadsM> { - static_assert(WarpNumThreadsM == 1, "WarpShapeM could not be evenly splited by threads"); - static constexpr int kP = 5; - static constexpr int kQ = 5; -}; - -template <> -struct SimtWarpShape<32, 1> { - static constexpr int kP = 4; - static constexpr int kQ = 8; -}; - -template <> -struct SimtWarpShape<32, 2> { - static constexpr int kP = 4; - static constexpr int kQ = 4; -}; - -template <> -struct SimtWarpShape<32, 4> { - static constexpr int kP = 2; - static constexpr int kQ = 4; -}; - -} // namespace detail - -template < - /// Shape of threadblock-scoped matrix multiply operator - typename Shape, - /// Shape of warp-level matrix multiply operator - typename WarpShape, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape, - /// Element data type of A operand - typename ElementA, - /// Layout of operand A - typename LayoutA, - /// Element data type of B operand - typename ElementB, - /// Layout of operand B - typename LayoutB, - /// Data type of accumulator - typename ElementC, - /// Layout of accumulator - typename LayoutC, - /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) - typename OperatorClass, - /// Size of a warp-scoped per thread access - int kLaneAccessSizeA_ = 0, - /// Size of a warp-scoped per thread access - int kLaneAccessSizeB_ = 0, - /// Number of stages - int Stages = 2, - /// Operation performed by MMA - typename Operator = typename platform::conditional< - (platform::is_same::value) && - (platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value), - cutlass::arch::OpMultiplyAddSaturate, - cutlass::arch::OpMultiplyAdd>::type, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA = - cutlass::arch::CacheOperation::Global, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB = - cutlass::arch::CacheOperation::Global, - /// per-element transformation for elements of A - ComplexTransform TransformA = ComplexTransform::kNone, - /// per-element transformation for elements of B - ComplexTransform TransformB = ComplexTransform::kNone, - bool IsComplex = false // (is_complex::value || is_complex::value) -> -struct DepthwiseMmaCoreWithLaneAccessSize; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Shape of threadblock-scoped matrix multiply operator - typename Shape, - /// Shape of threadblock-scoped output tile - typename ThreadBlockOutputShape, - /// Shape of filter shape per threadblock - typename FilterShape, - /// Shape of warp-level matrix multiply operator - typename WarpShape, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape, - /// Element data type of A operand - typename ElementA, - /// Layout of operand A - typename LayoutA, - /// Element data type of B operand - typename ElementB, - /// Layout of operand B - typename LayoutB, - /// Data type of accumulator - typename ElementC, - /// Layout of accumulator - typename LayoutC, - /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) - typename OperatorClass, - /// Size of a warp-scoped per thread access - int kLaneAccessSizeA_ = 0, - /// Size of a warp-scoped per thread access - int kLaneAccessSizeB_ = 0, - /// Number of stages - int Stages = 2, - /// Operation performed by MMA - typename Operator = typename platform::conditional< - (platform::is_same::value) && - (platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value), - cutlass::arch::OpMultiplyAddSaturate, - cutlass::arch::OpMultiplyAdd>::type, - /// Iterator algo type - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - /// Stride ( MatrixShape ) - typename StrideShape = cutlass::MatrixShape<-1, -1>, - /// Dilation ( MatrixShape ) - typename DilationShape = cutlass::MatrixShape<-1, -1>, - /// Activation Shape loaded by threadblock - typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA = - cutlass::arch::CacheOperation::Global, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB = - cutlass::arch::CacheOperation::Global, - /// per-element transformation for elements of A - ComplexTransform TransformA = ComplexTransform::kNone, - /// per-element transformation for elements of B - ComplexTransform TransformB = ComplexTransform::kNone, - bool IsComplex = false // (is_complex::value || is_complex::value) -> -struct DepthwiseDirectConvMmaCoreWithLaneAccessSize; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Shape of threadblock-scoped matrix multiply operator - typename Shape, - /// Shape of warp-level matrix multiply operator - typename WarpShape, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape, - /// Element data type of A operand - typename ElementA, - /// Layout of operand A - typename LayoutA, - /// Element data type of B operand - typename ElementB, - /// Layout of operand B - typename LayoutB, - /// Data type of accumulator - typename ElementC, - /// Layout of accumulator - typename LayoutC, - /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) - typename OperatorClass, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// per-element transformation for elements of A - ComplexTransform TransformA, - /// per-element transformation for elements of B - ComplexTransform TransformB, - bool IsComplex -> -struct DepthwiseMmaCoreWithLaneAccessSize< - Shape, WarpShape, InstructionShape, - ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - OperatorClass, -1, -1, Stages, Operator, AccumulatorsInRowMajor, - CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -> : cutlass::gemm::threadblock::DefaultMmaCore< - Shape, WarpShape, InstructionShape, - ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - OperatorClass, Stages, Operator, AccumulatorsInRowMajor, - CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -> {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: column-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Size of a warp-scoped per thread access (a value of -1 indicates the default) - int kLaneAccessSizeA_, - /// Size of a warp-scoped per thread access (a value of -1 indicates the default) - int kLaneAccessSizeB_, - /// Operation performed by GEMM - typename Operator_> -struct DepthwiseMmaCoreWithLaneAccessSize, - ElementA_, - layout::RowMajor, - ElementB_, - layout::ColumnMajor, - ElementC_, - LayoutC_, - arch::OpClassSimt, - kLaneAccessSizeA_, - kLaneAccessSizeB_, - 2, - Operator_> : public cutlass::gemm::threadblock::DefaultMmaCore, - ElementA_, - layout::RowMajor, - ElementB_, - layout::ColumnMajor, - ElementC_, - LayoutC_, - arch::OpClassSimt, - 2, - Operator_> { - using Base = cutlass::gemm::threadblock::DefaultMmaCore, - ElementA_, - layout::RowMajor, - ElementB_, - layout::ColumnMajor, - ElementC_, - LayoutC_, - arch::OpClassSimt, - 2, - Operator_>; - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - - static int const kLaneAccessSizeA = kLaneAccessSizeA_; - static int const kLaneAccessSizeB = kLaneAccessSizeB_; - - // Divisility requirements - static_assert( kLaneAccessSizeA > 0 && kLaneAccessSizeB > 0, - "Size of a warp-scoped per thread access should be larger then ZERO" ); - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = typename Base::WarpCount; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; - - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory are same as base class - // - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = cutlass::gemm::threadblock::detail::simt_get_warp_threads_m(); - static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = kLaneAccessSizeA / sizeof_bits::value; - static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - - static int const kPaddingM = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - static int const kPaddingN = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - - static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), - "Padding must be divisible by Lane"); - - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy< - MmaWarpSimt, - MatrixShape, // skew for A matrix to avoid SMEM bank conflicts - MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: row-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) - typename ThreadBlockOutputShape_, - /// Shape of filter shape per threadblock - typename FilterShape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Size of a warp-scoped per thread access - int kLaneAccessSizeA_, - /// Number of stages - int Stages_, - /// Operation performed by GEMM - typename Operator_> -struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, - ElementA_, - layout::RowMajor, - ElementB_, - layout::ColumnMajor, - ElementC_, - LayoutC_, - arch::OpClassSimt, - kLaneAccessSizeA_, - 128, - Stages_, - Operator_> { - using Shape = Shape_; - using FilterShape = FilterShape_; - using WarpShape = WarpShape_; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - - static int const kLaneAccessSizeB = 128; - - // Divisility requirements - static_assert( kLaneAccessSizeB > 0, - "Size of a warp-scoped per thread access should be larger then ZERO" ); - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = cutlass::gemm::GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - 1 - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - // For Gmem load - static int const kElementsPerAccessA = 128 / sizeof_bits::value; - static int const kElementsPerAccessB = 128 / sizeof_bits::value; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajor; - using SmemLayoutB = layout::RowMajor; - - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, // Set kStrided = 1 because activation shape is runtime value. - kThreads, - kElementsPerAccessA - >; - - /// ThreadMap of iterator A - using SmemThreadMapA = IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< - MatrixShape<1, Shape::kN>, // set kRow is 1 because it is a runtime value - ElementA, - SmemLayoutA, - 0, - SmemThreadMapA, // was IteratorThreadMapA - true // Dynamic iterations. - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccessB - >; - - /// Transpose the ThreadMap of iterator B - using SmemThreadMapB = IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - SmemThreadMapB, // was IteratorThreadMapB - false // static iterations. - >; - - // - // Warp-level matrix multiply operator - // - // Groups per threads - // Fp32: 2 groups - // Fp16: 2 groups - static const int GroupsPerThread = sizeof(ElementB) > 1 ? 2 : 4; - // Define the warp-level op - static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); - static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; - - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - - // Get output P, Q per thread - static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; - static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; - - static const int LaneLayout = 1; - static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; - static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); - - // Define the output tile computed by each thread - using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; - - // Fetch the channel with same access size - static const int LaneM = LaneN; - - // No paddings - static int const kPaddingM = 0; - static int const kPaddingN = 0; - - static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), - "Padding must be divisible by Lane"); - - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> - FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape - ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> - ThreadBlockOutputShape_, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< - MmaWarpSimt, - MatrixShape, // skew for A matrix to avoid SMEM bank conflicts - MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts - IteratorThreadMapA, - IteratorThreadMapB, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: row-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) - typename ThreadBlockOutputShape_, - /// Shape of filter shape per threadblock - typename FilterShape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Size of a warp-scoped per thread access - int kLaneAccessSizeA_, - /// Number of stages - int Stages_, - /// Operation performed by GEMM - typename Operator_, - /// Stride ( MatrixShape ) - typename StrideShape_, - /// Dilation ( MatrixShape ) - typename DilationShape_, - /// Activation Shape loaded by threadblock - typename ActivationShape_> -struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, - ElementA_, - layout::RowMajor, - ElementB_, - layout::ColumnMajor, - ElementC_, - LayoutC_, - arch::OpClassSimt, - kLaneAccessSizeA_, - 128, - Stages_, - Operator_, - IteratorAlgorithm::kFixedStrideDilation, - StrideShape_, - DilationShape_, - ActivationShape_> { - using Shape = Shape_; - using FilterShape = FilterShape_; - using WarpShape = WarpShape_; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - using StrideShape = StrideShape_; - using DilationShape = DilationShape_; - using ThreadBlockOutputShape = ThreadBlockOutputShape_; - using ActivationShape = ActivationShape_; - - static int const kLaneAccessSizeB = 128; - - // Divisility requirements - static_assert( kLaneAccessSizeB > 0, - "Size of a warp-scoped per thread access should be larger then ZERO" ); - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = cutlass::gemm::GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - 1 - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - // For Gmem load - static int const kElementsPerAccessA = 128 / sizeof_bits::value; - static int const kElementsPerAccessB = 128 / sizeof_bits::value; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajor; - using SmemLayoutB = layout::RowMajor; - - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccessA - >; - - /// ThreadMap of iterator A - using SmemThreadMapA = IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< - MatrixShape, - ElementA, - SmemLayoutA, - 0, - SmemThreadMapA, // was IteratorThreadMapA - false // static iterations. - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccessB - >; - - /// Transpose the ThreadMap of iterator B - using SmemThreadMapB = IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - SmemThreadMapB, // was IteratorThreadMapB - false // static iterations. - >; - - // - // Warp-level matrix multiply operator - // - // Groups per threads - // Fp32: 2 groups - // Fp16: 2 groups - static const int GroupsPerThread = sizeof(ElementB) > 1 ? 2 : 4; - // Define the warp-level op - static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); - static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; - - static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; - static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; - - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - - static const int LaneLayout = 1; - static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; - static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); - - // Define the output tile computed by each thread - using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; - - // Fetch the channel with same access size - static const int LaneM = LaneN; - - // No paddings - static int const kPaddingM = 0; - static int const kPaddingN = 0; - - static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), - "Padding must be divisible by Lane"); - - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> - FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape - ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> - ThreadBlockOutputShape, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - IteratorAlgorithm::kFixedStrideDilation, /// Iterator algo type - StrideShape, /// Stride ( MatrixShape ) - DilationShape, /// Dilation ( MatrixShape ) - ActivationShape /// Activation Shape loaded by threadblock - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< - MmaWarpSimt, - MatrixShape, // skew for A matrix to avoid SMEM bank conflicts - MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts - IteratorThreadMapA, - IteratorThreadMapB, - WarpCount::kK - >; -}; -} // namespace threadblock -} // namespace conv -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h deleted file mode 100644 index 482a52fe63209650546811aa24cafcc7419e7479..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h +++ /dev/null @@ -1,802 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a multistage threadblock-scoped fused activation's - scale+bias+relu and Implicit GEMM Convolution kernel. - - The original implicit gemm will store out-of-bound data as zeroes in the - shared memory because zeros into the tensor core, zeroes out of the tensor - cores. The result is remained the same. When fusing scale+bias+relu - into the mainloop, it is no longer true because - - 0 x scale + bias = bias - - which is no longer always 0. So, instead of storing zeroes, this fused - kernel stores the out-of-bound data as a special NaN (0x7eff), when applying - scale+bias+relu, the code is like - - if (data == 0x7eff) - data = 0; - else - data = scale+bias+relu(data, scale, bias); - - See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the - elementwise computation. See include/cutlass/arch/memory_sm80.h for nan fill. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/cache_operation.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -#include "cutlass/conv/warp/scale_bias_relu_transform.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Element type of scale and bias vectors - typename ElementScaleBias_, - /// Layout of scale and bias vectors - typename LayoutScaleBias_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// WarpIterator to load Scale or Bias vector from the shared memory - typename WarpIteratorScaleBias_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class MmaFpropFusionBase { - public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Element type of scale and bias vectors - using ElementScaleBias = ElementScaleBias_; - - /// Layout of scale and bias vectors - using LayoutScaleBias = LayoutScaleBias_; - - ///< Policy describing tuning details - using Policy = Policy_; - - ///< WarpIterator to load Scale or Bias vector from the shared memory - using WarpIteratorScaleBias = WarpIteratorScaleBias_; - - // - // Dependent types - // - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = cutlass::gemm::GemmShape; - - /// Number of warp-level GEMM oeprations - static int const kWarpGemmIterations = - (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the scale and bias vectors - using TensorRefScaleBias = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - static_assert(kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - static_assert((kWarpGemmIterations % 2) == 0, - "Inner loop iteration must be an even number."); - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = MatrixShape; - - /// Shape of the A scale and bias vectors in shared memory - using ShapeScaleBias = - MatrixShape<1 + Policy::SmemPaddingA::kRow, - 2 * Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - /// Buffer for A operand Scale and Bias - AlignedBuffer operand_A_scale_bias; - - public: - - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a layout object for the A scale and bias vectors - CUTLASS_DEVICE - static LayoutScaleBias LayoutScaleBias() { - return LayoutScaleBias::packed( - {ShapeScaleBias::kRow, ShapeScaleBias::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() { - return TensorRefB{operand_B.data(), LayoutB()}; - } - - /// Returns a TensorRef to the A operand Scale vector - CUTLASS_HOST_DEVICE - TensorRefScaleBias operand_A_scale_bias_ref() { - return TensorRefScaleBias{operand_A_scale_bias.data(), LayoutScaleBias()}; - } - }; - - protected: - - // - // Data members - // - - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; - - /// Iterator to load a warp-scoped tile of A operand scale and bias vector - /// from shared memory - WarpIteratorScaleBias warp_tile_iterator_A_scale_bias_; - - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaFpropFusionBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_A_scale_bias_( - shared_storage.operand_A_scale_bias_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Iterates over vectors of scale and bias vector in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorScaleBias_, - /// Iterates over vectors of scale and bias vector in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorScaleBias_, - /// Cache operation for scale/bias operand - cutlass::arch::CacheOperation::Kind CacheOpScaleBias, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// WarpIterator to load Scale or Bias vector from the shared memory - typename WarpIteratorScaleBias_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class ImplicitGemmFpropFusionMultistage - : public MmaFpropFusionBase { - public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Iterates over tiles of the scale and bias vectors in global memory - using IteratorScaleBias = IteratorScaleBias_; - ///< WarpIterator to load Scale or Bias vector from the shared memory - using WarpIteratorScaleBias = WarpIteratorScaleBias_; - ///< Policy describing tuning details - using Policy = Policy_; - ///< Base class - using Base = MmaFpropFusionBase; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScaleBias = SmemIteratorScaleBias_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - static cutlass::arch::CacheOperation::Kind const kCacheOpScaleBias = - CacheOpScaleBias; - - // - // Dependent types - // - - /// Fragment of accumulator tile - - using ElementC = typename Policy::Operator::ElementC; - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Internal structure exposed for introspection. - struct Detail { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - - private: - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpLoadedFragmentScaleBias = - typename WarpIteratorScaleBias::Fragment; - - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of A operand scale vector to shared memory - SmemIteratorScaleBias smem_iterator_A_scale_bias_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - ImplicitGemmFpropFusionMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_A_scale_bias_(shared_storage.operand_A_scale_bias_ref(), - thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_A_scale_bias_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, - IteratorScaleBias &iterator_A_scale_bias, - IteratorB &iterator_B, int group_start_A = 0, - int group_start_B = 0) { - iterator_A.set_iteration_index(group_start_A); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / 8; - - // Uses nan fill for out of bound data - cutlass::arch::cp_async_nan( - dst_ptr, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - - ++this->smem_iterator_A_; - } - } - - // Async Copy for operand A scale and bias vector. Scale and bias vectors - // are small. One iteration is enough. - if (group_start_A == 0) { - typename IteratorScaleBias::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_scale_bias_.get()); - - int const kSrcBytes = - sizeof_bits::value * - IteratorScaleBias::kElementsPerAccess / 8; - - cutlass::arch::cp_async( - dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid()); - } - - iterator_B.set_iteration_index(group_start_B); - - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale and bias vectors in global memory - IteratorScaleBias iterator_A_scale_bias, - ///< initial value of accumulator - FragmentC const &src_accum, - ///< number of iterations per channel - int gemm_k_iterations_per_channel = 0, - ///< Imaginary strides used for planar-complex only - ignored here - int64_t imag_stride_A = 0, - int64_t imag_stride_B = 0) { - - // - // Prologue - // - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / 8; - - // Uses Nan fill for out of bound data - cutlass::arch::cp_async_nan( - dst_ptr, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - ++this->smem_iterator_A_; - } - - // Async Copy for operand A scale and bias vectors. Scale and bias - // vectors are small. One iteration is enough. - { - typename IteratorScaleBias::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_scale_bias_.get()); - - int const kSrcBytes = - sizeof_bits::value * - IteratorScaleBias::kElementsPerAccess / 8; - - cutlass::arch::cp_async( - dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid()); - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.advance(); - iterator_A_scale_bias.advance(); - iterator_B.advance(); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_A_scale_bias_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Inserts a fence to group cp.async instructions into stages. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpLoadedFragmentA warp_loaded_frag_A[2]; - WarpLoadedFragmentB warp_loaded_frag_B[2]; - WarpLoadedFragmentScaleBias warp_loaded_frag_A_scale_bias[2]; - WarpTransformedFragmentA warp_transformed_frag_A[2]; - WarpTransformedFragmentB warp_transformed_frag_B[2]; - - Operator warp_mma; - cutlass::conv::warp::FpropScaleBiasReluTransform - elementwise_transform; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_A_scale_bias_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - this->warp_tile_iterator_A_scale_bias_.load( - warp_loaded_frag_A_scale_bias[0]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_A_scale_bias_; - ++this->warp_tile_iterator_B_; - - // Start issuing the first group of the next stage outside of the mainloop - copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], - warp_loaded_frag_A[0], warp_loaded_frag_B[0]); - - elementwise_transform(warp_transformed_frag_A[0], - warp_loaded_frag_A_scale_bias[0]); - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_scale_bias_.set_kgroup_index( - (warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_A_scale_bias_.load( - warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_A_scale_bias_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k > 0) { - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B[warp_mma_k % 2]); - - elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_loaded_frag_A_scale_bias[warp_mma_k % 2]); - } - - warp_mma( - accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - accum - ); - - // Issue global->shared copies for the next stage - int group_start_iteration_A, group_start_iteration_B; - - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - group_start_iteration_A = 0; - group_start_iteration_B = 0; - } else { - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - } - - copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B, - group_start_iteration_A, - group_start_iteration_B); - - - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - elementwise_transform( - warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - // Inserts a fence to group cp.async instructions into stages. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages of cp.async have committed - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.advance(); - iterator_A_scale_bias.advance(); - iterator_B.advance(); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_A_scale_bias_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_A_scale_bias_.add_tile_offset( - {0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_A_scale_bias_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - } - } - - } - - // Insert fence and wait for all outstanding cp.async operations to commit. - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h deleted file mode 100644 index 6c9c4792e289824afd1a761f5b7b4cc5972f167a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h +++ /dev/null @@ -1,539 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/cache_operation.h" -#include "cutlass/gemm/threadblock/mma_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class ImplicitGemmMultistage : - public gemm::threadblock::MmaBase { -public: - ///< Base class - using Base = gemm::threadblock::MmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Policy describing tuning details - using Policy = Policy_; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - // - // Dependent types - // - - /// Fragment of accumulator tile - - using ElementC = typename Policy::Operator::ElementC; - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Internal structure exposed for introspection. - struct Detail { - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical - // accuracy, where each mainloop iteration first accumulates into a temporary - // set of freshly-cleared accumulators, which are subsequently added to the - // final accumulator set. - static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation::value; - }; - - private: - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - ImplicitGemmMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance( - IteratorA &iterator_A, IteratorB &iterator_B, - int group_start_A = 0, int group_start_B = 0) { - - iterator_A.set_iteration_index(group_start_A * - IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); - - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< initial value of accumulator - FragmentC const &src_accum, - ///< number of iterations per channel - int gemm_k_iterations_per_channel = 0, - ///< Imaginary strides used for planar-complex only - ignored here - int64_t imag_stride_A = 0, - int64_t imag_stride_B = 0) { - - // - // Prologue - // - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.advance(); - iterator_B.advance(); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Inserts a fence to group cp.async instructions into stages. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpLoadedFragmentA warp_loaded_frag_A[2]; - WarpLoadedFragmentB warp_loaded_frag_B[2]; - WarpTransformedFragmentA warp_transformed_frag_A[2]; - WarpTransformedFragmentB warp_transformed_frag_B[2]; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - // Start issuing the first group of the next stage outside of the mainloop - copy_tiles_and_advance(iterator_A, iterator_B); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], - warp_loaded_frag_A[0], warp_loaded_frag_B[0]); - - // tf32x3 kernels use staging accumulation. warp_mma uses a temporary - // accumulator and this temporary accumulator is added to the final - // accumulator once in every mainloop iteration. - plus plus_accum; - - FragmentC tmp_accum; - - if (Detail::kStagedAccumulation) { - tmp_accum.clear(); - } - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k > 0) - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B[warp_mma_k % 2]); - - // Issue global->shared copies for the next stage - int group_start_iteration_A, group_start_iteration_B; - - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - group_start_iteration_A = 0; - group_start_iteration_B = 0; - } else { - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - } - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, - group_start_iteration_B); - - if (Detail::kStagedAccumulation) { - warp_mma( - tmp_accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - tmp_accum - ); - - if (warp_mma_k == 0) { - accum = plus_accum(accum, tmp_accum); - tmp_accum.clear(); - } - } else { - warp_mma( - accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - accum - ); - } - - if (warp_mma_k + 1 == Base::kWarpGemmIterations) - warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - // Inserts a fence to group cp.async instructions into stages. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages of cp.async have committed - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.advance(); - iterator_B.advance(); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - } - } - - } - - if (Detail::kStagedAccumulation) { - accum = plus_accum(accum, tmp_accum); - } - - // Insert fence and wait for all outstanding cp.async operations to commit. - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h deleted file mode 100644 index 45e27949665f797ba28afcd5f1cf98007c56eac9..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h +++ /dev/null @@ -1,320 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/aligned_buffer.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/mma_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Transformation applied to A operand - typename TransformA_ = NumericArrayConverter< - typename SmemIteratorA_::Element, - typename IteratorA_::Element, - IteratorA_::Fragment::kElements>, - /// - /// Transformation applied to A operand - typename TransformB_ = NumericArrayConverter< - typename SmemIteratorB_::Element, - typename IteratorB_::Element, - IteratorB_::Fragment::kElements>, - /// Used for partial specialization - typename Enable = bool -> -class ImplicitGemmPipelined : public gemm::threadblock::MmaBase { -public: - - ///< Base class - using Base = gemm::threadblock::MmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - using TransformA = TransformA_; - using TransformB = TransformB_; - - // - // Dependent types - // - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); - -private: - - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - -protected: - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - ImplicitGemmPipelined( - typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC &accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - FragmentC const &src_accum, ///< source accumulator tile - int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel - TransformA transform_A = TransformA(), ///< transformation applied to A fragment - TransformB transform_B = TransformB()) { ///< transformation applied to B fragment - - // - // Prologue - // - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A; - FragmentB tb_frag_B; - - tb_frag_A.clear(); - tb_frag_B.clear(); - - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - this->smem_iterator_A_.store(transform_A(tb_frag_A)); - this->smem_iterator_B_.store(transform_B(tb_frag_B)); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tightest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transform_A(tb_frag_A)); - - this->smem_iterator_B_.store(transform_B(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, - 0}); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k == 0) { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - } - - warp_mma(accum, warp_frag_A[warp_mma_k % 2], - warp_frag_B[warp_mma_k % 2], accum); - } - } - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h deleted file mode 100644 index 3be08c1ad90cf896b0b2191aa0c0a4a5a8c5b033..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h +++ /dev/null @@ -1,729 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a multistage threadblock-scoped fused activation's scale+bias+relu and - Implicit GEMM Convolution kernel. - - The original implicit gemm will store out-of-bound data as zeroes in the - shared memory because zeros into the tensor core, zeroes out of the tensor - cores. The result is remained the same. When fusing scale+bias+relu - into the mainloop, it is no longer true because - - 0 x scale + bias = bias - - which is no longer always 0. So, instead of storing zeroes, this fused - kernel stores the out-of-bound data as a special NaN (0x7eff), when applying - scale+bias+relu, the code is like - - if (data == 0x7eff) - data = 0; - else - data = scale+bias+relu(data, scale, bias); - - The biggest difference compared with the fused Fprop and scale+bias+relu is - that scale and bias are loop invariant in Wgrad so that they only needs to - be loaded once before the mainloop. - - See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the - elementwise computation. See include/cutlass/arch/memory_sm80.h for nan fill. - - -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/cache_operation.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -#include "cutlass/conv/warp/scale_bias_relu_transform.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Element type of scale and bias vectors - typename ElementScaleBias_, - /// Layout of scale and bias vectors - typename LayoutScaleBias_, - /// Element type of scale and bias vectors - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class MmaWgradFusionBase { - public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Element type of scale and bias vectors - using ElementScaleBias = ElementScaleBias_; - - /// Layout of scale and bias vectors - using LayoutScaleBias = LayoutScaleBias_; - - ///< Policy describing tuning details - using Policy = Policy_; - - // - // Dependent types - // - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = cutlass::gemm::GemmShape; - - /// Number of warp-level GEMM oeprations - static int const kWarpGemmIterations = - (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - static_assert(kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - static_assert((kWarpGemmIterations % 2) == 0, - "Inner loop iteration must be an even number."); - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - public: - - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - - protected: - - // - // Data members - // - - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; - - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaWgradFusionBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Iterates over vectors of scale and bias vector in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorScaleBias_, - /// Iterates over vectors of scale and bias vector i - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class ImplicitGemmWgradFusionMultistage - : public MmaWgradFusionBase { - public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Iterates over tiles of the scale and bias vectors in global memory - using IteratorScaleBias = IteratorScaleBias_; - ///< Policy describing tuning details - using Policy = Policy_; - ///< Base class - using Base = MmaWgradFusionBase; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - // - // Dependent types - // - - /// Fragment of accumulator tile - - using ElementC = typename Policy::Operator::ElementC; - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Internal structure exposed for introspection. - struct Detail { - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - static int const kBBufferSize = - ((sizeof(typename Operator::ElementC) == 4) && - ((platform::is_same::value && - platform::is_same::value)) && - (Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64)) - ? 1 - : 2; - }; - - private: - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpLoadedFragmentScaleBias = typename IteratorScaleBias::Fragment; - - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - int warp_idx_m_; - - int warp_idx_n_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - ImplicitGemmWgradFusionMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; - warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, - IteratorB &iterator_B, - int group_start_A = 0, int group_start_B = 0) { - - iterator_A.set_iteration_index(group_start_A); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B); - - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / 8; - - // Uses nan fill for out of bound data - cutlass::arch::cp_async_nan( - dst_ptr, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale and bias vectors in global memory - IteratorScaleBias iterator_B_scale_bias, - ///< initial value of accumulator - FragmentC const &src_accum, - ///< number of iterations per channel - int gemm_k_iterations_per_channel = 0, - ///< Imaginary strides used for planar-complex only - ignored here - int64_t imag_stride_A = 0, - int64_t imag_stride_B = 0) { - - // - // Prologue - // - - WarpLoadedFragmentScaleBias warp_loaded_frag_B_scale_bias; - iterator_B_scale_bias.add_tile_offset({0, warp_idx_n_}); - iterator_B_scale_bias.load(warp_loaded_frag_B_scale_bias); - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / 8; - - // Uses Nan fill for out of bound data - cutlass::arch::cp_async_nan( - dst_ptr, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.advance(); - iterator_B.advance(); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Inserts a fence to group cp.async instructions into stages. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpLoadedFragmentA warp_loaded_frag_A[Detail::kBBufferSize]; - WarpLoadedFragmentB warp_loaded_frag_B[2]; - WarpTransformedFragmentA warp_transformed_frag_A[Detail::kBBufferSize]; - WarpTransformedFragmentB warp_transformed_frag_B[2]; - - Operator warp_mma; - cutlass::conv::warp::WgradScaleBiasReluTransform - elementwise_transform; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - // Start issuing the first group of the next stage outside of the mainloop - copy_tiles_and_advance(iterator_A, iterator_B); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], - warp_loaded_frag_A[0], warp_loaded_frag_B[0]); - - elementwise_transform(warp_transformed_frag_B[0], - warp_loaded_frag_B_scale_bias); - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - if (Detail::kBBufferSize == 2) { - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize]); - ++this->warp_tile_iterator_A_; - } - - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_B_; - - if (warp_mma_k > 0) { - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize], - warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % Detail::kBBufferSize], - warp_loaded_frag_B[warp_mma_k % 2]); - - elementwise_transform(warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_B_scale_bias); - } - - warp_mma( - accum, - warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize], - warp_transformed_frag_B[warp_mma_k % 2], - accum - ); - - if (Detail::kBBufferSize == 1) { - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - ++this->warp_tile_iterator_A_; - - } - - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize], - warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize], - warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - elementwise_transform( - warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_B_scale_bias); - } - - // Issue global->shared copies for the next stage - int group_start_iteration_A, group_start_iteration_B; - - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - group_start_iteration_A = 0; - group_start_iteration_B = 0; - } else { - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - } - - copy_tiles_and_advance(iterator_A, iterator_B, - group_start_iteration_A, - group_start_iteration_B); - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - // Inserts a fence to group cp.async instructions into stages. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages of cp.async have committed - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.advance(); - iterator_B.advance(); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - } - } - - } - - // Insert fence and wait for all outstanding cp.async operations to commit. - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h deleted file mode 100644 index dac642385cd445e9a36a2f4c6f6c9e51f309cb87..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h +++ /dev/null @@ -1,470 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Templates calculating the address and predicates to the load of scale and bias vectors. - - This iterator uses masks to guard out-of-bounds accesses. - - A precomputed "Params" object minimizes the amount of state that must be - stored in registers, and integer addition is used to advance the pointer - through memory. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedScaleBiasVectorAccessIterator -/// -template -class PredicatedScaleBiasVectorAccessIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data. -/// -template -class PredicatedScaleBiasVectorAccessIterator { - public: - - using ThreadblockShape = ThreadblockShape_; - using Element = Element_; - using Layout = layout::PitchLinear; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using ConstPointer = const Element *; - using NonConstPointer = typename platform::remove_const::type *; - - static int const kElementsPerAccess = 128 / sizeof_bits::value; - static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess; - - using AccessType = AlignedArray; - - using Params = PredicatedScaleBiasVectorAccessIteratorParams; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Parameters object with precomputed internal state - Params const ¶ms_; - - /// Internal pointer to first access of tile - BytePointer pointer_; - - int problem_size_trs; - int problem_size_c; - int filter_trs_; - - TensorCoord thread_offset_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Extent of tensor - Conv2dProblemSize const &problem_size, - /// Pointer to the start of the scale vector - ConstPointer scale_pointer, - /// Pointer to the start of the bias vector - ConstPointer bias_pointer, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : params_(params), - problem_size_trs(problem_size.R * problem_size.S), - problem_size_c(problem_size.C), - filter_trs_(0) { - pointer_ = (thread_id < kThreads) - ? reinterpret_cast( - const_cast(scale_pointer)) - : reinterpret_cast( - const_cast(bias_pointer)); - - // Per-thread offset in logical coordinates of tensor - int thread_base = (thread_id < kThreads) ? 0 : kThreads; - - thread_offset_ = - threadblock_offset + - TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); - - set_iteration_index(0); - } - - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Extent of tensor - Conv3dProblemSize const &problem_size, - /// Pointer to the start of the scale vector - ConstPointer scale_pointer, - /// Pointer to the start of the bias vector - ConstPointer bias_pointer, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : params_(params), - problem_size_trs(problem_size.T * problem_size.R * problem_size.S), - problem_size_c(problem_size.C), - filter_trs_(0) { - pointer_ = (thread_id < kThreads) - ? reinterpret_cast( - const_cast(scale_pointer)) - : reinterpret_cast( - const_cast(bias_pointer)); - - // Per-thread offset in logical coordinates of tensor - int thread_base = (thread_id < kThreads) ? 0 : kThreads; - - thread_offset_ = - threadblock_offset + - TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); - - set_iteration_index(0); - } - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Extent of tensor - Conv2dProblemSize const &problem_size, - /// Pointer to start of scale vector - ConstPointer scale_pointer, - /// Pointer to start of scale vector - ConstPointer bias_pointer, - ///< ID of each participating thread - int thread_id) - : PredicatedScaleBiasVectorAccessIterator(params, problem_size, - scale_pointer, bias_pointer, - thread_id, make_Coord(0, 0)) {} - - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Extent of tensor - Conv3dProblemSize const &problem_size, - /// Pointer to start of scale vector - ConstPointer scale_pointer, - /// Pointer to start of scale vector - ConstPointer bias_pointer, - ///< ID of each participating thread - int thread_id) - : PredicatedScaleBiasVectorAccessIterator(params, problem_size, - scale_pointer, bias_pointer, - thread_id, make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) {} - - /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles - CUTLASS_DEVICE - void add_tile_offset( - TensorCoord const &tile_offset) { - thread_offset_ = - thread_offset_ + - TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - - return reinterpret_cast( - pointer_ + - (thread_offset_.contiguous() * sizeof_bits::value / 8)); - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator &operator++() { - return *this; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - void advance() { - // moves to the next tile - ++filter_trs_; - if (filter_trs_ == problem_size_trs) { - filter_trs_ = 0; - add_tile_offset(TensorCoord(1, 0)); - } - } - - /// Increment and return an instance to self. - CUTLASS_DEVICE - PredicatedScaleBiasVectorAccessIterator operator++(int) { - PredicatedScaleBiasVectorAccessIterator self(*this); - operator++(); - return self; - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - uint32_t enabled = 0; - -#if defined(_MSC_VER) || (__CUDACC_VER_MAJOR__ < 11) - enabled = threadIdx.x < kThreads * 2; -#else - asm volatile( - "{\n" - " .reg .u32 tid_reg;\n" - " .reg .pred p;\n" - " mov.u32 tid_reg, %%tid.x;\n" - " setp.lt.u32 p, tid_reg, %1;\n" - " selp.u32 %0, 1, 0, p;\n" - "}\n" : "+r"(enabled) :"n"(kThreads * 2)); -#endif - - return ((thread_offset_.contiguous() < problem_size_c) && enabled); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for row-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedScaleBiasVectorAccessIterator { - public: - - using ThreadblockShape = ThreadblockShape_; - using Element = Element_; - using Layout = layout::RowMajor; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using ConstPointer = const Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator< - layout::PitchLinearShape, - Element, - layout::PitchLinear>; - - using AccessType = typename UnderlyingIterator::AccessType; - static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; - - using Params = PredicatedScaleBiasVectorAccessIteratorParams; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - ///< Precomputed parameters object - Params const ¶ms, - ///< Extent of tensor - Conv2dProblemSize const &problem_size, - ///< Pointer to the start of the scale vector - ConstPointer scale_pointer, - ///< Pointer to the start of the bias vector - ConstPointer bias_pointer, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params, problem_size, scale_pointer, bias_pointer, - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), - threadblock_offset.row())) {} - - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - ///< Precomputed parameters object - Params const ¶ms, - ///< Extent of tensor - Conv3dProblemSize const &problem_size, - ///< Pointer to the start of the scale vector - ConstPointer scale_pointer, - ///< Pointer to the start of the bias vector - ConstPointer bias_pointer, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params, problem_size, scale_pointer, bias_pointer, - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), - threadblock_offset.row())) {} - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Conv2dProblemSize const &problem_size, ///< Extent of tensor - ConstPointer scale_pointer, ///< Pointer to the start of the scale vector - ConstPointer bias_pointer, ///< Pointer to the start of the bias vector - int thread_id ///< ID of each participating thread - ) - : PredicatedScaleBiasVectorAccessIterator(params, problem_size, - scale_pointer, bias_pointer, - thread_id, make_Coord(0, 0)) {} - - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Conv3dProblemSize const &problem_size, ///< Extent of tensor - ConstPointer scale_pointer, ///< Pointer to the start of the scale vector - ConstPointer bias_pointer, ///< Pointer to the start of the bias vector - int thread_id ///< ID of each participating thread - ) - : PredicatedScaleBiasVectorAccessIterator(params, problem_size, - scale_pointer, bias_pointer, - thread_id, make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// threadblock tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator operator++(int) { - PredicatedScaleBiasVectorAccessIterator self(*this); - operator++(); - return self; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - void advance() { - iterator_.advance(); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h deleted file mode 100644 index e9844be9f000920fd82f18dc6dab5755611f08ea..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h +++ /dev/null @@ -1,371 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Templates calculating the address and predicates to the load of scale and bias vectors. - - This iterator uses masks to guard out-of-bounds accesses. - - A precomputed "Params" object minimizes the amount of state that must be - stored in registers, and integer addition is used to advance the pointer - through memory. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedScaleBiasVectorIterator -/// -template -class PredicatedScaleBiasVectorIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for wgrad pitch-linear data. -/// -template -class PredicatedScaleBiasVectorIterator { - public: - - using WarpShape = WarpShape_; - using Element = Element_; - using Layout = layout::PitchLinear; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using ConstPointer = const Element *; - using NonConstPointer = typename platform::remove_const::type *; - - static int const kElementsPerAccess = 1; - - using AccessType = AlignedArray; - - static int const kIterations = WarpShape::kContiguous / 8; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>; - - /// Parameters object is precomputed state and is host-constructible - using Params = Conv2dWgradActivationIteratorOptimizedParams; - - private: - // - // Data members - // - - /// Parameters object with precomputed internal state - Params const ¶ms_; - - /// Internal pointer to first access of tile - ConstPointer scale_pointer_; - ConstPointer bias_pointer_; - - /// Size of tensor - Conv2dProblemSize problem_size_; - - int32_t thread_offset_; - - // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k - int32_t filter_c_[kIterations]; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Extent of tensor - Conv2dProblemSize const &problem_size, - /// Pointer to the start of the scale vector - ConstPointer scale_pointer, - /// Pointer to the start of the bias vector - ConstPointer bias_pointer, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : params_(params), - problem_size_(problem_size), - scale_pointer_(scale_pointer), - bias_pointer_(bias_pointer) { - - thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4; - } - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Extent of tensor - Conv2dProblemSize const &problem_size, - /// Pointer to start of scale vector - ConstPointer scale_pointer, - /// Pointer to start of scale vector - ConstPointer bias_pointer, - ///< ID of each participating thread - int thread_id) - : PredicatedScaleBiasVectorIterator(params, problem_size, - scale_pointer, bias_pointer, - thread_id, make_Coord(0, 0)) {} - - /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles - CUTLASS_DEVICE - void add_tile_offset( - TensorCoord const &tile_offset) { - - thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous()); - - CUTLASS_PRAGMA_UNROLL - for(int c = 0; c < kIterations; ++c) { - int rsc_offset = thread_offset_ + c * 8; - - int residual, tmp; - params_.sc_divmod(tmp, residual, rsc_offset); - params_.c_divmod(tmp, filter_c_[c], residual); - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - - frag.fill(__float2half2_rn(0.0f)); - __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag); - - // load scale - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < kIterations; ++c) { - - cutlass::arch::global_load< - __half, - sizeof(AccessType) - >( - frag_ptr[c * 2].x, - scale_pointer_ + filter_c_[c], - true - ); - } - - // load bias - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < kIterations; ++c) { - - cutlass::arch::global_load< - __half, - sizeof(AccessType) - >( - frag_ptr[c * 2 + 1].x, - bias_pointer_ + filter_c_[c], - true - ); - } - - // duplicate scale - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < kIterations; ++c) { - frag_ptr[c * 2].y = frag_ptr[c * 2].x; - } - - // duplicate bias - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < kIterations; ++c) { - frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x; - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for row-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedScaleBiasVectorIterator { - public: - - using WarpShape = WarpShape_; - using Element = Element_; - using Layout = layout::RowMajor; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using ConstPointer = const Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedScaleBiasVectorIterator< - layout::PitchLinearShape, - Element, - layout::PitchLinear>; - - using AccessType = typename UnderlyingIterator::AccessType; - static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; - using Fragment = typename UnderlyingIterator::Fragment; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedScaleBiasVectorIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Conv2dProblemSize const &problem_size, Layout const &layout) - : params_(problem_size, layout::TensorNHWC(0, 0, 0)){}; - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorIterator( - ///< Precomputed parameters object - Params const ¶ms, - ///< Extent of tensor - Conv2dProblemSize const &problem_size, - ///< Pointer to the start of the scale vector - ConstPointer scale_pointer, - ///< Pointer to the start of the bias vector - ConstPointer bias_pointer, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params.params_, problem_size, scale_pointer, bias_pointer, - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), - threadblock_offset.row())) {} - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorIterator( - Params const ¶ms, ///< Precomputed parameters object - Conv2dProblemSize const &problem_size, ///< Extent of tensor - ConstPointer scale_pointer, ///< Pointer to the start of the scale vector - ConstPointer bias_pointer, ///< Pointer to the start of the bias vector - int thread_id ///< ID of each participating thread - ) - : PredicatedScaleBiasVectorIterator(params, problem_size, - scale_pointer, bias_pointer, - thread_id, make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// threadblock tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - iterator_.load(frag); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace conv -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h deleted file mode 100644 index 0c5aed6dba0fa206fcab9545eeeb165558cb724a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h +++ /dev/null @@ -1,193 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implements several possible threadblock-swizzling functions mapping blockIdx to - Convolution problems. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/platform/platform.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// -CUTLASS_HOST_DEVICE -static int get_strided_dgrad_tile_m( - cutlass::conv::Conv2dProblemSize const &problem_size, - int tile_size_m) { - - // CTAs in M dimension per starting filter position - int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, tile_size_m); - - // Inflate number of CTAs in M dimension to cover every strating filter position even those that - // may fall out of valid MMA (Dy * w) but are needed to apply epilogue (beta * Dx_source) - // and point-wise fusion - int tile_m = tile_m_per_filter * int(problem_size.stride().product()); - - // There is a possible performance optimization here that leads up to 2x speeds than the current - // CUTLASS strided dgrad performance for stride > filter, i.e., stride={2x2} and filter={1x1}) - // - // * Optimization * - // Only launch CTAs in M dimension which contribute to a row in Dx output - // - // - // * Constraints * - // (A) stride <= filter, for example, stride={2x2} and filter={3x3}: - // - (A.1): There are no constraints for this case and the optimization does - // affect this case functionality or performance. - // (B) stride > filter, for example, stride={2x2} and filter={1x1}: - // - (B.1): Dx output tensor should be zero initialized - // - (B.2): The kernel epilogue cannot apply beta. Thus, beta should be zero - - return tile_m; -} -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Threadblock swizzling function for strided dgrad convolution -struct StridedDgradHorizontalThreadblockSwizzle : - public gemm::threadblock::GemmHorizontalThreadblockSwizzle { - - using Base = gemm::threadblock::GemmHorizontalThreadblockSwizzle; - - CUTLASS_HOST_DEVICE - StridedDgradHorizontalThreadblockSwizzle() { } - - /// Returns the shape of the problem in units of logical tiles - /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) - CUTLASS_HOST_DEVICE - static gemm::GemmCoord get_tiled_shape( - cutlass::conv::Operator conv_operator, - cutlass::conv::Conv2dProblemSize const &problem_size, - gemm::GemmCoord tile_size, - int split_k_slices) { - - gemm::GemmCoord implicit_gemm_problem_size = - cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); - - // compute number of tiles in m dimension - int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); - - // compute number of tiles in n dimension - int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); - - return gemm::GemmCoord( - tile_m, - tile_n, - split_k_slices); - } - - /// Returns the shape of the problem in units of logical tiles - /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) - private: - using Base::get_tiled_shape; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Threadblock swizzling function for strided dgrad convolution -template -struct StridedDgradIdentityThreadblockSwizzle : - public gemm::threadblock::GemmIdentityThreadblockSwizzle { - - using Base = gemm::threadblock::GemmIdentityThreadblockSwizzle; - - CUTLASS_HOST_DEVICE - StridedDgradIdentityThreadblockSwizzle() { } - - /// Returns the shape of the problem in units of logical tiles - /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) - CUTLASS_HOST_DEVICE - static gemm::GemmCoord get_tiled_shape( - cutlass::conv::Operator conv_operator, - cutlass::conv::Conv2dProblemSize const &problem_size, - gemm::GemmCoord tile_size, - int split_k_slices) { - - gemm::GemmCoord implicit_gemm_problem_size = - cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); - - // compute number of tiles in m dimension - int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); - - // compute number of tiles in n dimension - int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); - - return gemm::GemmCoord( - tile_m, - tile_n, - split_k_slices); - } - - /// Returns the shape of the problem in units of logical tiles - /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) - private: - using Base::get_tiled_shape; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Threadblock swizzling function for GEMMs -template -struct DepthwiseDirect2dConvIdentityThreadblockSwizzle - : public gemm::threadblock::GemmIdentityThreadblockSwizzle { - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvIdentityThreadblockSwizzle() {} - - /// Returns the shape of the problem in units of logical tiles - CUTLASS_HOST_DEVICE - static gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator, - cutlass::conv::Conv2dProblemSize const &problem_size, - gemm::GemmCoord tile_size, - int split_k_slices) { - - gemm::GemmCoord implicit_gemm_problem_size = - cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); - - return gemm::GemmCoord(1, - (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(), - split_k_slices); - } -}; - -} // namespace threadblock -} // namespace conv -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h deleted file mode 100644 index b7af2e37bd610a12f334943902395b6956362589..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h +++ /dev/null @@ -1,380 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/thread/mma.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/thread/depthwise_mma.h" - - -#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" -#include "cutlass/gemm/warp/mma_simt_policy.h" - -#include "cutlass/gemm/warp/mma_simt.h" -#include "cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK = 1, - /// Complex transformation on operand A - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transformation on operand B - ComplexTransform TransformB = ComplexTransform::kNone, - /// Used for partial specialization - typename Enable = bool> -class MmaDepthwiseSimt - : public cutlass::gemm::warp:: - MmaSimt { - using Base = cutlass::gemm::warp:: - MmaSimt; - -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassSimt; - - /// Hard-coded for now - using ArchTag = arch::Sm50; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = TransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = TransformB; - -public: - - /// Iterates over the B operand in memory - using IteratorB = cutlass::conv::warp::DepthwiseMmaSimtTileIterator< - MatrixShape, - cutlass::gemm::Operand::kB, - ElementB, - LayoutB, - Policy, - PartitionsK, - Shape::kK - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentB = FragmentB; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaDepthwiseSimt():Base() {} -}; - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Shape of filter shape per threadblock - concept: gemm::GemmShape - typename FilterShape_, - /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> - typename ThreadOutputShape_, - /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> - typename ThreadBlockOutputShape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Iterator algo type - conv::IteratorAlgorithm IteratorAlgorithm_ = IteratorAlgorithm::kAnalytic, - /// Stride ( MatrixShape ) - typename StrideShape_ = cutlass::MatrixShape<-1, -1>, - /// Dilation ( MatrixShape ) - typename DilationShape_ = cutlass::MatrixShape<-1, -1>, - /// Activation Shape loaded by threadblock - typename ActivationShape_ = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, - /// Number of partitions along K dimension - int PartitionsK = 1, - /// Complex transformation on operand A - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transformation on operand B - ComplexTransform TransformB = ComplexTransform::kNone, - /// Used for partial specialization - typename Enable = bool> -class MmaDepthwiseDirectConvSimt { - public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Shape of filter shape per threadblock - concept: gemm::GemmShape - using FilterShape = FilterShape_; - - /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> - using ThreadOutputShape = ThreadOutputShape_; - - /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> - using ThreadBlockOutputShape = ThreadBlockOutputShape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Iterator algo type - static conv::IteratorAlgorithm const IteratorAlgorithm = IteratorAlgorithm_; - - /// Stride ( MatrixShape ) - using StrideShape = StrideShape_; - - /// Dilation ( MatrixShape ) - using DilationShape = DilationShape_; - - /// Activation Shape loaded by threadblock - using ActivationShape = ActivationShape_; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassSimt; - - /// Hard-coded for now - using ArchTag = arch::Sm50; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = TransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = TransformB; - - static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value || - platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) && - platform::is_same< ElementA, int8_t >::value && - platform::is_same< ElementB, int8_t >::value; - - using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type; - - /// Thread-level matrix multiply accumulate operator - using ThreadMma = cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct< - cutlass::gemm::GemmShape< - Shape::kM / Policy::WarpShape::kRow, // number of output pixels proccessed per thread - Shape::kN / Policy::WarpShape::kColumn, // number of channels proccessed per thread - 1>, - ElementA, - ElementB, - ElementC, - arch::OpMultiplyAdd, - dp4a_type - >; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Shape of the underlying instruction - using InstructionShape = cutlass::gemm::GemmShape<1,1,use_dp4a ? 4 : 1>; - -public: - - /// Iterates over the A operand in memory - using IteratorA = cutlass::conv::warp::DepthwiseDirect2dConvSimtTileIterator< - MatrixShape, // per warp - FilterShape, - ThreadOutputShape, - ThreadBlockOutputShape, - cutlass::gemm::Operand::kA, - ElementA, - Policy, - IteratorAlgorithm, - StrideShape, - DilationShape, - ActivationShape, - PartitionsK, - Shape::kK - >; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = FragmentA; - - /// Iterates over the B operand in memory - using IteratorB = cutlass::gemm::warp::MmaSimtTileIterator< - MatrixShape<1, Shape::kN>, - cutlass::gemm::Operand::kB, - ElementB, - LayoutB, - Policy, - PartitionsK, - Shape::kK - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentB = FragmentB; - - /// Iterates over the C operand in memory - using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< - MatrixShape, - cutlass::gemm::Operand::kC, - ElementC, - LayoutC, - Policy - >; - - /// Storage for C tile - using FragmentC = typename ThreadMma::FragmentC; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaDepthwiseDirectConvSimt() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &d, - FragmentA a, - FragmentB b, - FragmentC const &c, int group_idx = 0) const { - - ThreadMma mma; - - mma(d, a, b, c); - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - dst_A = A; - dst_B = B; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace conv -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h deleted file mode 100644 index 47fd1e08b9ff9f693b462fd89f5230475d918120..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h +++ /dev/null @@ -1,862 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT - instructions -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/conv/convolution.h" - -#include "cutlass/arch/memory_sm75.h" - -#include "cutlass/layout/matrix.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma_simt_policy.h" -#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions -/// -/// concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity - cutlass::gemm::Operand Operand, - /// Data type of A elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Number of partitions along K dimension - used in sliced-K - int PartitionsK = 1, - /// Group Size along kPartition - used in sliced-K - int PartitionGroupSize = 1 -> -class DepthwiseMmaSimtTileIterator; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization for B operands of row-major layouts -/// -/// Concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK, - /// Group Size along kPartition - used in sliced-K - int PartitionGroupSize> -class DepthwiseMmaSimtTileIterator - : public cutlass::gemm::warp::MmaSimtTileIterator { - - using Base = cutlass::gemm::warp::MmaSimtTileIterator; - public: - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kB; - - /// Element type - using Element = Element_; - - /// Layout of policy - using Layout = layout::RowMajor; - - /// Decomposition of elements among threads - using Policy = Policy_; - - /// TensorRef type for loading element from a tensor - using TensorRef = typename Base::TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Thread-level shape of a fragment - using ThreadShape = typename Base::ThreadShape; - - /// Number of individual loads - using Iterations = typename Base::Iterations; - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - - static_assert(Policy::LaneMmaShape::kN == 1, "Each thread should be 1 element per LDS along the k-dim"); - -private: - - MatrixCoord lane_offset_; - int channel_idx_; - int base_channel_idx_; - int warps_n_; - - public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - DepthwiseMmaSimtTileIterator():Base() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - DepthwiseMmaSimtTileIterator( - TensorRef ref, - int lane_id - ) : Base(ref, lane_id) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - warps_n_ = -1; - channel_idx_ = 0; - base_channel_idx_ = 0; - lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - DepthwiseMmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - - if(warps_n_ == -1){ - warps_n_ = coord.column(); - } - - Base::add_tile_offset(coord); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { - Array *dst_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kRow; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kColumn; ++n) { - - void const *ptr = this->ref_.data() + - this->ref_.offset({-(channel_idx_ - base_channel_idx_), - n * Policy::WarpShape::kColumn}) + - pointer_offset / Policy::LaneMmaShape::kN; - - // Base_k of a warp + Base_k of current threads. - int thread_k_base_idx = - warps_n_ * Shape::kColumn / Policy::LaneMmaShape::kN + lane_offset_.column(); - - if (channel_idx_ + k == thread_k_base_idx + n * Policy::WarpShape::kColumn) { - // Depthwise kernel would only do computation when channel == k. - // Loads an element when the current computation channel == the k corresponding to this thread. - arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr); - } else { - // Reduce SMEM load - dst_ptr[n + k * Iterations::kColumn].fill(Element(0)); - } - } - } - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - if(k_group % PartitionGroupSize == 0 && k_group != 0){ - base_channel_idx_ = k_group; - } - channel_idx_ = k_group; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Size of filter (concept: gemm::GemmShape) - typename FilterShape_, - /// Size of the matrix to load (concept: MatrixShape) - typename ThreadOutputShape_, - /// Size of the matrix to load (concept: MatrixShape) - typename ThreadBlockOutputShape_, - /// Operand identity - cutlass::gemm::Operand Operand, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Iterator algo type - conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - /// Stride ( MatrixShape ) - typename StrideShape = cutlass::MatrixShape<-1, -1>, - /// Dilation ( MatrixShape ) - typename DilationShape = cutlass::MatrixShape<-1, -1>, - /// Activation Shape loaded by threadblock - typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, - /// Number of partitions along K dimension - used in sliced-K - int PartitionsK = 1, - /// Group Size along kPartition - used in sliced-K - int PartitionGroupSize = 1> -class DepthwiseDirect2dConvSimtTileIterator; - - -/// Specialization for A operands of row-major layouts -/// -/// Concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Size of filter (concept: gemm::GemmShape) - typename FilterShape_, - /// Size of the matrix to load (concept: TensorNHWC) - typename ThreadOutputShape_, - /// Size of the matrix to load (concept: TensorNHWC) - typename ThreadBlockOutputShape_, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Iterator algo type - conv::IteratorAlgorithm IteratorAlgorithm, - /// Stride ( MatrixShape ) - typename StrideShape, - /// Dilation ( MatrixShape ) - typename DilationShape, - /// Activation Shape loaded by threadblock - typename ActivationShape, - /// Number of partitions along K dimension - used in sliced-K - int PartitionsK, - /// Group Size along kPartition - used in sliced-K - int PartitionGroupSize> -class DepthwiseDirect2dConvSimtTileIterator { - public: - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Shape of filter (concept: gemm::GemmShape) - using FilterShape = FilterShape_; - - /// Shape of tile to load (concept: TensorNHWC) - using ThreadOutputShape = ThreadOutputShape_; - - /// Shape of tile to load (concept: TensorNHWC) - using ThreadBlockOutputShape = ThreadBlockOutputShape_; - - /// Operand tag - static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; - - /// Element type - using Element = Element_; - - /// Layout of policy - using Layout = layout::RowMajor; - - /// Decomposition of elements among threads - using Policy = Policy_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - // - // Derived quantities - // - - static_assert(!(Shape::kRow % Policy::WarpShape::kRow), - "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); - - static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); - static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); - static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); - static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); - -// Thread-level shape of a fragment - using ThreadShape = MatrixShape< - ThreadOutputShape::kNHW, // Output tile shape Computed by current threads - ThreadOutputShape::kC - >; - - static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), - "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); - - /// Number of individual loads - using Iterations = MatrixShape< - ThreadShape::kRow, - ThreadShape::kColumn / Policy::LaneMmaShape::kN - >; - - using ThreadTileCount = MatrixShape< - ThreadBlockOutputShape::kH / ThreadOutputShape::kH, - ThreadBlockOutputShape::kW / ThreadOutputShape::kW - >; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -protected: - - /// Internal reference - cutlass::TensorRef, layout::RowMajor> ref_; - - int activation_offset[ThreadOutputShape::kH][ThreadOutputShape::kW][Iterations::kColumn]; - int iterator_r_; - int iterator_s_; - int iterator_offset_; - - int inc_next_s_ ; - int inc_next_r_ ; - - MatrixCoord lane_offset_; -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator( - TensorRef ref, - int lane_id - ) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - // Set channel offset - lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); - - ref.add_coord_offset(lane_offset_); - - ref_.reset(reinterpret_cast *>(ref.data()), - ref.stride(0) / Policy::LaneMmaShape::kN); - - iterator_r_ = 0; - iterator_s_ = 0; - iterator_offset_ = 0; - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - template - CUTLASS_HOST_DEVICE - void setup_initial_status(Params const& params) { - - inc_next_s_ = params.inc_next[0]; - inc_next_r_ = params.inc_next[1]; - - // Get base HW offset of current threads - int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); - int base_p_ = - (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH; - int base_q_ = - (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW; - - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < ThreadOutputShape::kH; ++p) { - CUTLASS_PRAGMA_UNROLL - for (int q = 0; q < ThreadOutputShape::kW; ++q) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < Iterations::kColumn; ++col) { - int base_w = (base_q_ + q) * params.stride[0]; - int base_h = (base_p_ + p) * params.stride[1]; - - int offset = base_h * params.activation_tile_w + base_w; - activation_offset[p][q][col] = offset; - } - } - } - } - - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - // Set warp row and col start - lane_offset_ = MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - void advance(int32_t pointer_offset) { - ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); - iterator_s_ = 0; - iterator_r_ = 0; - iterator_offset_ = 0; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator &operator++() { - ++iterator_s_; - if (iterator_s_ < FilterShape::kColumn) { - iterator_offset_ += inc_next_s_; - - return *this; - } - - iterator_s_ = 0; - - ++iterator_r_; - if (iterator_r_ < FilterShape::kRow) { - iterator_offset_ += inc_next_r_; - return *this; - } - - iterator_r_ = 0; - iterator_offset_ = 0; - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator & operator--() { - // Do nothing - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { - - Array *dst_ptr = - reinterpret_cast *>(&frag); - - - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < ThreadOutputShape::kH; ++p) { - CUTLASS_PRAGMA_UNROLL - for (int q = 0; q < ThreadOutputShape::kW; ++q) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kColumn; ++n) { - void const *ptr = ref_.data() + - ref_.offset({activation_offset[p][q][n] + (iterator_offset_), - n * Policy::WarpShape::kColumn}) + - pointer_offset / Policy::LaneMmaShape::kN; - arch::shared_load(dst_ptr[n + q + p * ThreadOutputShape::kW], ptr); - } - } - } - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { - // Do nothing at present. - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag, Index pointer_offset) const { - store_with_pointer_offset(frag, 0); - } - - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/// Specialization for A operands of row-major layouts -/// -/// Concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Size of filter (concept: gemm::GemmShape) - typename FilterShape_, - /// Size of the matrix to load (concept: TensorNHWC) - typename ThreadOutputShape_, - /// Size of the matrix to load (concept: TensorNHWC) - typename ThreadBlockOutputShape_, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Stride ( MatrixShape ) - typename StrideShape_, - /// Dilation ( MatrixShape ) - typename DilationShape_, - /// Activation Shape loaded by threadblock - typename ActivationShape_, - /// Number of partitions along K dimension - used in sliced-K - int PartitionsK, - /// Group Size along kPartition - used in sliced-K - int PartitionGroupSize> -class DepthwiseDirect2dConvSimtTileIterator { - public: - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Shape of filter (concept: gemm::GemmShape) - using FilterShape = FilterShape_; - - /// Shape of tile to load (concept: TensorNHWC) - using ThreadOutputShape = ThreadOutputShape_; - - /// Shape of tile to load (concept: TensorNHWC) - using ThreadBlockOutputShape = ThreadBlockOutputShape_; - - /// Stride ( MatrixShape ) - using StrideShape = StrideShape_; - - /// Dilation ( MatrixShape ) - using DilationShape = DilationShape_; - - /// Activation Shape loaded by threadblock - using ActivationShape = ActivationShape_; - - /// Operand tag - static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; - - /// Element type - using Element = Element_; - - /// Layout of policy - using Layout = layout::RowMajor; - - /// Decomposition of elements among threads - using Policy = Policy_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - // - // Derived quantities - // - - static_assert(!(Shape::kRow % Policy::WarpShape::kRow), - "The warp-level GEMM M size must be divisible by the number of threads arranged " - "along the M dimension."); - - static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); - static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); - static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); - static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, - "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); - - // Activations loaded by threadblock - static int const ThreadActivationShapeH = (ThreadOutputShape::kH - 1) * StrideShape::kRow + - (FilterShape::kRow - 1) * DilationShape::kRow + 1; - - static int const ThreadActivationShapeW = (ThreadOutputShape::kW - 1) * StrideShape::kColumn + - (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; - - using ThreadActivationShape = cutlass::conv:: - TensorNHWCShape<1, ThreadActivationShapeH, ThreadActivationShapeW, ThreadOutputShape::kC>; - - // Thread-level shape of a fragment - using ThreadShape = - MatrixShape; - - static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), - "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); - - /// Number of individual loads - using Iterations = - MatrixShape; - - using ThreadTileCount = MatrixShape; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - - protected: - /// Internal reference - cutlass::TensorRef, layout::RowMajor> ref_; - - Array - activation[ThreadActivationShape::kH][ThreadActivationShape::kW][Iterations::kColumn]; - int iterator_r_; - int iterator_s_; - - - MatrixCoord lane_offset_; - - public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator() {} - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator(TensorRef ref, int lane_id) { - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - // Set channel offset - lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); - - ref.add_coord_offset(lane_offset_); - - ref_.reset(reinterpret_cast *>(ref.data()), - ref.stride(0) / Policy::LaneMmaShape::kN); - - iterator_r_ = 0; - iterator_s_ = 0; - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - template - CUTLASS_HOST_DEVICE void setup_initial_status( - Params const ¶ms) { - - // Get base HW offset of current threads - int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); - int base_h = - (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH * StrideShape::kRow; - int base_w = - (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW * StrideShape::kColumn; - - CUTLASS_PRAGMA_UNROLL - for (int h = 0; h < ThreadActivationShape::kH; ++h) { - CUTLASS_PRAGMA_UNROLL - for (int w = 0; w < ThreadActivationShape::kW; ++w) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < Iterations::kColumn; ++col) { - int offset = (base_h + h) * ActivationShape::kW + (base_w + w); - - void const *ptr = ref_.data() + ref_.offset({offset, col * Policy::WarpShape::kColumn}); - arch::shared_load(activation[h][w][col], ptr); - } - } - } - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - // Set warp row and col start - lane_offset_ = - MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - void advance(int32_t pointer_offset) { - ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); - iterator_s_ = 0; - iterator_r_ = 0; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator &operator++() { - ++iterator_s_; - if (iterator_s_ < FilterShape::kColumn) { - return *this; - } - - iterator_s_ = 0; - - ++iterator_r_; - if (iterator_r_ < FilterShape::kRow) { - return *this; - } - - iterator_r_ = 0; - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - DepthwiseDirect2dConvSimtTileIterator &operator--() { - // Do nothing - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { - Array *dst_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < ThreadOutputShape::kH; ++p) { - CUTLASS_PRAGMA_UNROLL - for (int q = 0; q < ThreadOutputShape::kW; ++q) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kColumn; ++n) { - const int h = p * StrideShape::kRow + iterator_r_ * DilationShape::kRow; - const int w = q * StrideShape::kColumn + iterator_s_ * DilationShape::kColumn; - - dst_ptr[n + q + p * ThreadOutputShape::kW] = activation[h][w][n]; - } - } - } - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { - // Do nothing at present. - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag, Index pointer_offset) const { - store_with_pointer_offset(frag, 0); - } - - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - -} // namespace warp -} // namespace conv -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h deleted file mode 100644 index 6cb3935a7e070f0dc34b1ec9c31d9ac448d43b8b..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h +++ /dev/null @@ -1,221 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level per channel scale+bias+relu before - matrix multiply-accumulate operations targeting Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace conv { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FpropScaleBiasReluTransform { - - using T = typename FragmentActivations::Element; - - static int const NumActivations = FragmentActivations::kElements; - static int const NumScaleBias = FragmentScaleBias::kElements; - static int const MmaElements = 2; - // One element has one scale and one bias - static int const MmaScaleBiasPair = 2; - // 16816 has 2 columns - static int const MmaCols = 2; - - using MmaOperand = Array; - using ScaleBiasOperand = Array; - - CUTLASS_DEVICE - void transform(MmaOperand &activations, ScaleBiasOperand const &scale_bias) { - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - uint32_t *ptr_activations = reinterpret_cast(&activations); - uint32_t const *ptr_scale_bias = reinterpret_cast(&scale_bias); - - // Apply per channel scale+bias+relu if the data is not a special NaN - // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. - - // We assumes the pair of FP16 are either both inbound or both out-of-bound. - // It requires C to be an even number. - asm volatile( - "{\n\t" - " .reg .pred %%p;\n\t" - " .reg .b32 t1;\n\t" - " setp.eq.u32 %%p, %2, %4;\n\t" - " fma.rn.f16x2.relu t1, %1, %2, %3;\n" - " selp.u32 %0, 0, t1, %%p;\n\t" - "}\n" - : "=r"(ptr_activations[0]) - : "r"(ptr_scale_bias[0]), "r"(ptr_activations[0]), - "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16x2)); -#else - assert(0); -#endif - } - - CUTLASS_DEVICE - void operator()(FragmentActivations &activations, - FragmentScaleBias const &scale_bias) { - MmaOperand *ptr_activations = reinterpret_cast(&activations); - ScaleBiasOperand const *ptr_scale_bias = - reinterpret_cast(&scale_bias); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < (NumActivations / MmaElements); ++i) { - transform(ptr_activations[i], ptr_scale_bias[(i / MmaScaleBiasPair) % MmaCols]); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct WgradScaleBiasReluTransform { - - using T = typename FragmentActivations::Element; - - static int const NumActivations = FragmentActivations::kElements; - static int const NumScaleBias = FragmentScaleBias::kElements; - static int const MmaElements = 2; - // One element has one scale and one bias - static int const MmaScaleBiasPair = 2; - // 16816 has 2 rows - static int const MmaRows = 2; - - using MmaOperand = Array; - using ScaleBiasOperand = Array<__half2, MmaScaleBiasPair>; - - CUTLASS_DEVICE - void transform(MmaOperand &activations, ScaleBiasOperand const &scale_bias) { - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - - __half2 *ptr_activations = reinterpret_cast<__half2 *>(&activations); - uint32_t const *ptr_scale_bias = reinterpret_cast(&scale_bias); - -#if 1 - // CUDA + PTX version - - bool h1_oob = (reinterpret_cast(ptr_activations[0].x) == cutlass::arch::OOB_NAN_F16); - bool h2_oob = (reinterpret_cast(ptr_activations[0].y) == cutlass::arch::OOB_NAN_F16); - - // Apply per channel scale+bias+relu if the data is not a special NaN - // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. - - // We cannot gurantee that the pair of F16 are both in bound or both - // out-of-bound because C x R x S can be an odd number. - asm volatile( - "{\n\t" - " fma.rn.f16x2.relu %0, %1, %2, %3;\n" - "}" - : "=r"(reinterpret_cast(ptr_activations[0])) - : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), - "r"(ptr_scale_bias[1])); - - reinterpret_cast(ptr_activations[0]) = h1_oob ? - (reinterpret_cast(ptr_activations[0]) & 0xffff0000) : - reinterpret_cast(ptr_activations[0]); - - reinterpret_cast(ptr_activations[0]) = h2_oob ? - (reinterpret_cast(ptr_activations[0]) & 0xffff) : - reinterpret_cast(ptr_activations[0]); -#else - // pure PTX version - - // Apply per channel scale+bias+relu if the data is not a special NaN - // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. - asm volatile( - "{\n" - " .reg .b16 t1, t2;\n" - " .reg .b32 t3, t4, t5, t6;\n" - " .reg .pred p1, p2;\n" - " mov.b32 {t1, t2}, %2;\n" - " setp.eq.s16 p1, t1, %4;\n" - " setp.eq.s16 p2, t2, %4;\n" - " fma.rn.f16x2.relu t3, %1, %2, %3;\n" - " and.b32 t4, t3, %5;\n" - " selp.b32 t5, t4, t3, p1;\n" - " and.b32 t6, t5, %6;\n" - " selp.b32 %0, t6, t5, p2;\n" - "}\n" - : "=r"(reinterpret_cast(ptr_activations[0])) - : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), - "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16), "n"(0xffff0000), "n"(0x0000ffff)); -#endif -#else - assert(0); -#endif - } - - CUTLASS_DEVICE - void operator()(FragmentActivations &activations, - FragmentScaleBias const &scale_bias) { - MmaOperand *ptr_activations = reinterpret_cast(&activations); - ScaleBiasOperand const *ptr_scale_bias = - reinterpret_cast(&scale_bias); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < (NumActivations / MmaElements); ++i) { - transform(ptr_activations[i], ptr_scale_bias[(i / MmaRows)]); - } - } -}; -} // namespace warp -} // namespace conv -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/coord.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/coord.h deleted file mode 100644 index 16cfa1b322f24f3e1c64f14b91dd880798e3b68d..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/coord.h +++ /dev/null @@ -1,478 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief A Coord is a coordinate of arbitrary rank into a tensor or matrix -*/ - -#pragma once -#include "cutlass/cutlass.h" -#if defined(__CUDACC_RTC__) -#include CUDA_STD_HEADER(cstdint) -#else -#include -#endif - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Statically-sized array specifying Coords within a tensor -template < - int Rank_, ///< Logical rank of coordinate - typename Index_ = int, ///< Index type used for each dimension - typename LongIndex_ = int64_t ///< Long index type used for linear offsets -> -struct Coord { - -public: - - // - // Type and constant definitions - // - - /// Number of elements in Coord - static int const kRank = Rank_; - - /// Index type used to store elements - using Index = Index_; - - /// Type used to represent linear offsets - using LongIndex = LongIndex_; - -private: - - // - // Data members - // - - /// Indices - Index idx[kRank]; - -public: - - // - // Methods - // - - /// Default ctor initializes uniformly - CUTLASS_HOST_DEVICE - explicit Coord(Index value = Index(0)) { - for (int i = 0; i < kRank; ++i) { - idx[i] = value; - } - } - - /// Constructs from an array of integers - CUTLASS_HOST_DEVICE - Coord(Index const (&_idx)[kRank]) { - for (int i = 0; i < kRank; ++i) { - idx[i] = _idx[i]; - } - } - - /// Constructs from some other Coord - template - CUTLASS_HOST_DEVICE - Coord(Coord other) { - for (int i = 0; i < kRank; ++i) { - idx[i] = other[i]; - } - } - - /// Returns a slice of the Coord which may be larger or smaller in rank - /// than this. - template - CUTLASS_HOST_DEVICE - Coord slice(int start = 0, Index identity = 0) const { - Coord result; - for (int i = 0; i < Slice; ++i) { - if (i + start < kRank) { - result[i] = idx[i + start]; - } - else { - result[i] = identity; - } - } - return result; - } - - /// Returns the index of the dimension with least value - CUTLASS_HOST_DEVICE - int min_dim_index() const { - int i = 0; - for (int j = 1; j < kRank; ++j) { - if (idx[j] < idx[i]) { - i = j; - } - } - return i; - } - - /// Returns the index of the dimension with greatest value - CUTLASS_HOST_DEVICE - int max_dim_index() const { - int i = 0; - for (int j = 1; j < kRank; ++j) { - if (idx[j] > idx[i]) { - i = j; - } - } - return i; - } - - /// Returns true if Coord is non-zero. - CUTLASS_HOST_DEVICE - explicit operator bool() const { - for (int i = 0; i < kRank; ++i) { - if (idx[i]) { - return true; - } - } - return false; - } - - /// Returns true if Coord is uniformly zero. - CUTLASS_HOST_DEVICE - bool operator!() const { - for (int i = 0; i < kRank; ++i) { - if (idx[i]) { - return false; - } - } - return true; - } - - /// Element-wise addition - CUTLASS_HOST_DEVICE - Coord operator+(Coord const& b) const { - Coord c; - for (int i = 0; i < kRank; ++i) { - c.idx[i] = idx[i] + b.idx[i]; - } - return c; - } - - /// Element-wise subtraction - CUTLASS_HOST_DEVICE - Coord operator-(Coord const& b) const { - Coord c; - for (int i = 0; i < kRank; ++i) { - c.idx[i] = idx[i] - b.idx[i]; - } - return c; - } - - /// Element-wise multiplication - CUTLASS_HOST_DEVICE - Coord operator*(Coord const& b) const { - Coord c; - for (int i = 0; i < kRank; ++i) { - c.idx[i] = idx[i] * b.idx[i]; - } - return c; - } - - /// Element-wise division - CUTLASS_HOST_DEVICE - Coord operator/(Coord const& b) const { - Coord c; - for (int i = 0; i < kRank; ++i) { - c.idx[i] = idx[i] / b.idx[i]; - } - return c; - } - - /// In-place addition - CUTLASS_HOST_DEVICE - Coord& operator+=(Coord const& b) { - for (int i = 0; i < kRank; ++i) { - idx[i] += b.idx[i]; - } - return *this; - } - - /// In-place subtraction - CUTLASS_HOST_DEVICE - Coord& operator-=(Coord const& b) { - for (int i = 0; i < kRank; ++i) { - idx[i] -= b.idx[i]; - } - return *this; - } - - /// In-place multiplication - CUTLASS_HOST_DEVICE - Coord& operator*=(Coord const& b) { - for (int i = 0; i < kRank; ++i) { - idx[i] *= b.idx[i]; - } - return *this; - } - - /// In-place division - CUTLASS_HOST_DEVICE - Coord& operator/=(Coord const& b) { - for (int i = 0; i < kRank; ++i) { - idx[i] /= b.idx[i]; - } - return *this; - } - - /// Member access operator - CUTLASS_HOST_DEVICE Index& operator[](int dim) { return idx[dim]; } - - /// Member access operator - CUTLASS_HOST_DEVICE Index const& operator[](int dim) const { return idx[dim]; } - - /// Computes the dot product with anotherCoord object - CUTLASS_HOST_DEVICE - LongIndex dot(Coord const& b, LongIndex sum = LongIndex(0)) const { - for (int i = 0; i < kRank; ++i) { - sum += idx[i] * b.idx[i]; - } - return sum; - } - - /// Gets the index of a given Coord element - template - CUTLASS_HOST_DEVICE Index& at() { - return idx[Dim]; - } - - /// Access via index; may limit unrolling potential - CUTLASS_HOST_DEVICE - Index& at(int dim) { return idx[dim]; } - - /// Gets the index of a given Coord element - template - CUTLASS_HOST_DEVICE Index const& at() const { - return idx[Dim]; - } - - /// Access via index; may limit unrolling potential - CUTLASS_HOST_DEVICE - Index const& at(int dim) const { return idx[dim]; } - - /// Determines if two Coord<> objects are equal - CUTLASS_HOST_DEVICE - bool operator==(Coord const& b) const { - bool equal = true; - for (int i = 0; equal && i < kRank; ++i) { - equal = (idx[i] == b.idx[i]); - } - return equal; - } - - /// Not equal - CUTLASS_HOST_DEVICE - bool operator!=(Coord const& b) const { return !(*this == b); } - - /// Clamps a coordinate to a range specified by maximum and minimum values - CUTLASS_HOST_DEVICE - Coord& clamp(Coord const& max, Coord const& min = Coord()) { - for (int i = 0; i < kRank; ++i) { - idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]); - } - return *this; - } - - /// Returns the sum of all elements - CUTLASS_HOST_DEVICE - Index sum() const { - Index sum_(idx[0]); - for (int i = 1; i < kRank; ++i) { - sum_ += idx[i]; - } - return sum_; - } - - /// Returns the product of all elements - CUTLASS_HOST_DEVICE - LongIndex product() const { - LongIndex product_(idx[0]); - for (int i = 1; i < kRank; ++i) { - product_ *= idx[i]; - } - return product_; - } - - /// Less than operator - CUTLASS_HOST_DEVICE - bool operator<(Coord const &b) const { - for (int i = 0; i < kRank; ++i) { - if (!(idx[i] < b[i])) { - return false; - } - } - return true; - } - - /// Less than or equals operator - CUTLASS_HOST_DEVICE - bool operator<=(Coord const &b) const { - for (int i = 0; i < kRank; ++i) { - if (!(idx[i] <= b[i])) { - return false; - } - } - return true; - } - - /// Greater than operator - CUTLASS_HOST_DEVICE - bool operator>(Coord const &b) const { - return !(*this <= b); - } - - /// Greater than or equals operator - CUTLASS_HOST_DEVICE - bool operator>=(Coord const &b) const { - return !(*this < b); - } -}; - -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - - -/// Scalar multiplication -template -CUTLASS_HOST_DEVICE -Coord operator*(Index s, Coord coord) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Rank; ++i) { - coord[i] *= s; - } - return coord; -} - -/// Scalar multiplication -template -CUTLASS_HOST_DEVICE -Coord operator*(Coord coord, Index s) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Rank; ++i) { - coord[i] *= s; - } - return coord; -} - -/// Scalar division -template -CUTLASS_HOST_DEVICE -Coord operator/(Index s, Coord coord) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Rank; ++i) { - coord[i] = s / coord[i]; - } - return coord; -} - -/// Scalar division -template -CUTLASS_HOST_DEVICE -Coord operator/(Coord coord, Index s) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Rank; ++i) { - coord[i] /= s; - } - return coord; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Integer-valued make_Coord -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Helper to make a 1-element coordinate -template -CUTLASS_HOST_DEVICE -Coord<1, T> make_Coord(T _0) { - T values[1] = {_0}; - return Coord<1, T>(values); -} - -/// Helper to make a 2-element coordinate -template -CUTLASS_HOST_DEVICE -Coord<2, T> make_Coord(T _0, T _1) { - T values[2] = {_0, _1}; - return Coord<2, T>(values); -} - -/// Helper to make a 3-element coordinate -template -CUTLASS_HOST_DEVICE -Coord<3, T> make_Coord(T _0, T _1, T _2) { - T values[3] = {_0, _1, _2}; - return Coord<3, T>(values); -} - -/// Helper to make a 4-element coordinate -template -CUTLASS_HOST_DEVICE -Coord<4, T> make_Coord(T _0, T _1, T _2, T _3) { - T values[4] = {_0, _1, _2, _3}; - return Coord<4, T>(values); -} - -/// Helper to make a 5-element coordinate -template -CUTLASS_HOST_DEVICE -Coord<5, T> make_Coord(T _0, T _1, T _2, T _3, T _4) { - T values[5] = {_0, _1, _2, _3, _4}; - return Coord<5, T>(values); -} - -/// Helper to make a 1-element coordinate -template -CUTLASS_HOST_DEVICE -Coordmake_Coord_with_padding(T _0) { - Coord coord; - - CUTLASS_PRAGMA_UNROLL - for (int i = N - 1; i > 0; --i) { - coord[i] = 0; - } - - coord[0] = _0; - - return coord; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/core_io.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/core_io.h deleted file mode 100644 index 046b3063a8ca7e1b79248ddea8d10af239eb4bdb..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/core_io.h +++ /dev/null @@ -1,328 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Helpers for printing cutlass/core objects -*/ -#pragma once - -#include -#include - -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix.h" -#include "cutlass/quaternion.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm_enumerated_types.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Output operator for CUDA built-in dim3 type -inline std::ostream &operator<<(std::ostream &out, dim3 d) { - return out << d.x << ", " << d.y << ", " << d.z; -} - -/// Output operator for CUDA built-in error type -inline std::ostream &operator<<(std::ostream &out, cudaError_t error) { - return out << cudaGetErrorString(error); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// stream operators for cutlass namespace // -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline -std::ostream& operator<<(std::ostream& out, Array const& v) { - for (int i = 0; i < Rank; ++i) { - out << (i ? ", " : "") << v[i]; - } - return out; -} - -template -inline -std::ostream& operator<<(std::ostream& out, Coord const& coord) { - for (int i = 0; i < Rank; ++i) { - out << (i ? ", " : "") << coord[i]; - } - return out; -} - -inline -std::istream & operator>>(std::istream &stream, half_t &x) { - float tmp; - stream >> tmp; - x = static_cast(tmp); - return stream; -} - -inline -std::ostream & operator<<(std::ostream &out, half_t const &x) { - return out << float(x); -} - -inline -std::ostream & operator<<(std::ostream &out, bfloat16_t const &x) { - return out << float(x); -} - -inline -std::ostream & operator<<(std::ostream &out, tfloat32_t const &x) { - return out << float(x); -} - - -inline -std::ostream & operator<<(std::ostream &out, float_e2m1_t const &x) { - return out << float(x); -} - -inline -std::ostream & operator<<(std::ostream &out, detail::float_e2m1_unpacksmem_t const &x) { - return out << float(x); -} - -inline -std::ostream & operator<<(std::ostream &out, float_e3m2_t const &x) { - return out << float(x); -} - -inline -std::ostream & operator<<(std::ostream &out, float_e2m3_t const &x) { - return out << float(x); -} - -inline -std::ostream & operator<<(std::ostream &out, detail::float_e3m2_unpacksmem_t const &x) { - return out << float(x); -} - -inline -std::ostream & operator<<(std::ostream &out, detail::float_e2m3_unpacksmem_t const &x) { - return out << float(x); -} - -inline -std::ostream & operator<<(std::ostream &out, float_ue8m0_t const &x) { - return out << float(x); -} - -inline -std::ostream & operator<<(std::ostream &out, float_ue4m3_t const &x) { - return out << float(x); -} - - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Helper to enable formatted printing of CUTLASS scalar types to an ostream -template -struct ScalarIO { - - /// Value to print - T value; - - /// Default ctor - ScalarIO() { } - - /// Constructs from a value - ScalarIO(T value): value(value) {} -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Default printing to ostream -template -inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { - return out << scalar.value; -} - -/// Printing to ostream of int8_t as integer rather than character -template <> -inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { - return out << int(scalar.value); -} - -/// Printing to ostream of uint8_t as integer rather than character -template <> -inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { - return out << unsigned(scalar.value); -} - - -/// Default printing to ostream for MatrixShape -template -inline -std::ostream & operator<<(std::ostream &out, MatrixShape const &matrix_shape) { - out << "cutlass::MatrixShape::(kRow, kColumn) {" - << cutlass::MatrixShape::kRow <<"," - << cutlass::MatrixShape::kColumn <<"}"; - return out; -} - - -/// Prints matrix to ostream -template -std::ostream & operator<<(std::ostream &out, Matrix const &rhs) { - - for (int i = 0; i < Rows; ++i) { - for (int j = 0; j < Columns; ++j) { - ScalarIO element(rhs.at(i, j)); - out << (j ? ", " : "") << element; - } - out << "\\n"; - } - - return out; -} - -template -std::ostream &operator<<(std::ostream &out, Quaternion const &rhs) { - - out << ScalarIO(rhs.w()) << " "; - if (rhs.x() >= 0) { - out << "+"; - } - - out << ScalarIO(rhs.x()) << "*i "; - if (rhs.y() >= 0) { - out << "+"; - } - - out << ScalarIO(rhs.y()) << "*j "; - if (rhs.z() >= 0) { - out << "+"; - } - - out << ScalarIO(rhs.z()) << "*k"; - - return out; -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// stream operators for cutlass::gemm namespace // -/////////////////////////////////////////////////////////////////////////////////////////////////// -namespace gemm { - -/// Default printing to ostream for GemmShape -template -inline -std::ostream & operator<<(std::ostream &out, GemmShape const &gemm_shape) { - out << "cutlass::gemm::GemmShape::(kM, kN, kK) {" - << cutlass::gemm::GemmShape::kM <<"," - << cutlass::gemm::GemmShape::kN <<"," - << cutlass::gemm::GemmShape::kK << "}"; - return out; -} - -/// Default printing to ostream for GemmCoord -inline -std::ostream & operator<<(std::ostream &out, GemmCoord const &gemm_coord) { - out << "cutlass::gemm::GemmCoord {" - << gemm_coord.m() <<"," - << gemm_coord.n() <<"," - << gemm_coord.k() << "}"; - return out; -} - -} //namespace gemm -/////////////////////////////////////////////////////////////////////////////////////////////////// - - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// stream operators for cutlass namespace // -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Default printing to ostream for PitchLinearShape -template < int Contiguous, int Strided> -inline -std::ostream & operator<<(std::ostream &out, PitchLinearShape const &pitch_linear_shape) { - out << "cutlass::PitchLinearShape:(kContiguous, kStrided) {" - << cutlass::layout::PitchLinearShape::kContiguous <<"," - << cutlass::layout::PitchLinearShape::kStrided <<"}"; - return out; -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// stream operators for cutlass::conv namespace // -/////////////////////////////////////////////////////////////////////////////////////////////////// -namespace conv { -/// Default printing to ostream for Conv2dProblemSize -inline -std::ostream& operator<<(std::ostream& out, Conv2dProblemSize const& problem) { - out << "NHWC: (" << problem.N << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl - << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C / problem.groups << ")" << std::endl - << "NPQK: (" << problem.N << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl - << "groups: (" << problem.groups << ")" << std::endl - << "Pad_h, Pad_w: (" << problem.pad_h << ", " << problem.pad_w << ")" << std::endl - << "Stride_h, Stride_w: (" << problem.stride_h << ", " << problem.stride_w << ")" << std::endl - << "Dilation_h, Dilation_w: (" << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl - << "split_k_slices: (" << problem.split_k_slices << ")" << std::endl - << "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? "conv" : "xcross") << ")"; - - return out; -} - - -/// Default printing to ostream for Conv3dProblemSize -inline -std::ostream& operator<<(std::ostream& out, Conv3dProblemSize const& problem) { - out << "NDHWC: (" << problem.N << ", " << problem.D << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl - << "KTRSC: (" << problem.K << ", " << problem.T << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl - << "NZPQK: (" << problem.N << ", " << problem.Z << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl - << "pad_d, pad_h, pad_w: (" << problem.pad_d << ", " << problem.pad_h << ", " << problem.pad_w << ")" << std::endl - << "stride_d, stride_h, stride_w: (" << problem.stride_d << ", " << problem.stride_h << ", " << problem.stride_w << ")" << std::endl - << "dilation_d, dilation_h, dilation_w: (" << problem.dilation_d << ", " << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl - << "split_k_slices: (" << problem.split_k_slices << ") " << std::endl - << "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? "conv" : "xcross") << ")"; - - return out; -} - -} // namespace conv -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/cuda_host_adapter.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/cuda_host_adapter.hpp deleted file mode 100644 index a8af62be2d3e27ccf499acaead03dc3aadd4c151..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/cuda_host_adapter.hpp +++ /dev/null @@ -1,428 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Interface between a CUTLASS device-wide operator and CUDA. -*/ - -#pragma once - -#include -#include "cutlass/cutlass.h" -#include "cutlass/trace.h" - -#include "cutlass/platform/platform.h" -#if ! defined(__CUDACC_RTC__) -#include -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// NVRTC doesn't need definitions for these host classes - -#if ((__CUDACC_VER_MAJOR__ >= 12) || \ - ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) \ - && !defined(__CUDACC_RTC__) -#define CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED -#endif - -#if ((__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__)) -#define CUDA_HOST_ADAPTER_TENSORMAP_ENABLED -#endif - -// Include for CUDA Driver API calls if any of these capabilities are enabled. -#if defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) || \ - defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) - -#include - -#endif // defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) || - // defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Macro-level guard for CUDA Host Adapter -// -#if !defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) -#define CUTLASS_ENABLE_CUDA_HOST_ADAPTER false -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -#if !defined(__CUDACC_RTC__) - -#if ((__CUDACC_VER_MAJOR__ >= 12) || \ - ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) -#include -#endif // (__CUDACC_VERSION__ >= 11.8) - -#include - -#define CUTLASS_CUDA_DRIVER_STRINGIFY(tok) #tok - -#if defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) - -#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ - template \ - CUresult call_##func(Args... args) { \ - return func(args...); \ - } - -#else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) - -#if (__CUDACC_VER_MAJOR__ > 12) - -#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ - template \ - CUresult call_##func(Args... args) { \ - cudaDriverEntryPointQueryResult cuda_status; \ - void* pfn = nullptr; \ - cudaError_t cuda_err = cudaGetDriverEntryPointByVersion( \ - CUTLASS_CUDA_DRIVER_STRINGIFY(func), \ - &pfn, ver, \ - cudaEnableDefault, \ - &cuda_status); \ - if (cuda_status != cudaDriverEntryPointSuccess || \ - cuda_err != cudaSuccess) { \ - return CUDA_ERROR_UNKNOWN; \ - } \ - return reinterpret_cast(pfn)(args...); \ - } - -#else - -#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ - template \ - CUresult call_##func(Args... args) { \ - cudaDriverEntryPointQueryResult cuda_status; \ - void* pfn = nullptr; \ - cudaError_t cuda_err = cudaGetDriverEntryPoint( \ - CUTLASS_CUDA_DRIVER_STRINGIFY(func), \ - &pfn, \ - cudaEnableDefault, \ - &cuda_status); \ - if (cuda_status != cudaDriverEntryPointSuccess || \ - cuda_err != cudaSuccess) { \ - return CUDA_ERROR_UNKNOWN; \ - } \ - return reinterpret_cast(pfn)(args...); \ - } - -#endif // (__CUDACC_VER_MAJOR__ > 12) - -#endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) - -#if (__CUDACC_VER_MAJOR__ >= 12) -CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeTiled, 12000); -CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeIm2col, 12000); -#endif - -#undef CUTLASS_CUDA_DRIVER_STRINGIFY - -#define CUTLASS_CUDA_DRIVER_WRAPPER_CALL(func) cutlass::call_##func - -#endif // !defined(__CUDACC_RTC__) - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter -/// CudaHostLaunchAttributes will be an empty struct in earlier CTK where CUlaunchAttribute -/// is not introduced. -struct CudaHostLaunchAttributes { - -#if defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) - - /// Reasonable maximum launch attributes that are commonly applied - static constexpr int32_t kMaximumAttributeCount = 5; - - /// Launch attributes - CUlaunchAttribute launch_attributes[kMaximumAttributeCount]; - int32_t attribute_count = 0; - - CUTLASS_HOST_DEVICE - CudaHostLaunchAttributes(CUlaunchAttribute *launch_attributes_ = nullptr, - int32_t attribute_count_ = 0) { - CUTLASS_ASSERT(attribute_count_ >= 0 && attribute_count_ < kMaximumAttributeCount); - for (int32_t i = 0; i < attribute_count_ && i < kMaximumAttributeCount; ++i) { - launch_attributes[i] = launch_attributes_[i]; - } - attribute_count = attribute_count_; - } - - CUTLASS_HOST_DEVICE - CUlaunchAttribute const* data() const { - return launch_attributes; - } - - CUTLASS_HOST_DEVICE - size_t size() const { - return attribute_count; - } - -#endif // (CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) - -}; - - -/// This class defines an object which abstracts interactions between the CUTLASS device-wide GEMM and -/// CUDA. The intention is to enable CUTLASS to be used with both the CUDA Runtime API and CUDA Driver API. -struct CudaHostAdapter { - - /// Limit the number of kernels - static constexpr int32_t kMaximumKernelCount = 4; - - /// Maximum cluster size - static constexpr int MaxClusterSize = 32; - - // - // Data members - // - - /// Handles - void *kernel_handles[kMaximumKernelCount]; - int32_t kernel_count = 0; - - CudaHostLaunchAttributes launch_attributes; - - // - // Methods - // - - /// Ctor - CudaHostAdapter() = default; - - /// Dtor - virtual ~CudaHostAdapter() = default; - - /// Copy Ctor - CUTLASS_HOST_DEVICE - CudaHostAdapter(const CudaHostAdapter & rhs) - : kernel_count(rhs.kernel_count), - launch_attributes(rhs.launch_attributes) { - CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); - - for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { - kernel_handles[i] = rhs.kernel_handles[i]; - } - } - - /// Copy Assignment - CUTLASS_HOST_DEVICE - CudaHostAdapter& operator=(const CudaHostAdapter & rhs) { - CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); - for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { - kernel_handles[i] = rhs.kernel_handles[i]; - } - kernel_count = rhs.kernel_count; - - launch_attributes = rhs.launch_attributes; - - return *this; - } - - - /// Move ctor - CUTLASS_HOST_DEVICE - CudaHostAdapter(CudaHostAdapter && rhs) - : kernel_count(rhs.kernel_count), - launch_attributes(std::move(rhs.launch_attributes)) { - CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); - - for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { - kernel_handles[i] = rhs.kernel_handles[i]; - } - } - - // / Move assignment - CUTLASS_HOST_DEVICE - CudaHostAdapter& operator=(CudaHostAdapter && rhs) { - CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); - for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { - kernel_handles[i] = rhs.kernel_handles[i]; - } - kernel_count = rhs.kernel_count; - launch_attributes = std::move(rhs.launch_attributes); - return *this; - } - - /// Ctor - CUTLASS_HOST_DEVICE - CudaHostAdapter(void **kernel_handles_, - int32_t kernel_count_, - CudaHostLaunchAttributes const &launch_attributes_ = { }) - : kernel_count(kernel_count_), - launch_attributes(launch_attributes_) { - CUTLASS_ASSERT(kernel_count >= 0 && kernel_count < kMaximumKernelCount); - - for (int32_t i = 0; i < kernel_count && i < kMaximumKernelCount; ++i) { - kernel_handles[i] = kernel_handles_[i]; - } - } - - /// Returns true if the CudaHostAdapter is empty (kernel_count == 0) - CUTLASS_HOST_DEVICE - bool empty() const { return !kernel_count; } - - /// Returns kernel_count - CUTLASS_HOST_DEVICE - size_t size() const { return static_cast(kernel_count); } - - /// Queries the occupancy of a kernel - virtual Status query_occupancy( - int32_t *device_sms, - int32_t *sm_occupancy, - int32_t kernel_index, - int32_t thread_count, - int32_t smem_size) const = 0; - - /// Launches a kernel without using Threadblock Clusters. - virtual Status launch( - dim3 const grid_dims, - dim3 const block_dims, - size_t const smem_size, - cudaStream_t cuda_stream, - void** kernel_params, - int32_t kernel_index) const = 0; - - /// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters. - virtual Status launch( - dim3 const grid_dims, - dim3 const cluster_dims, - dim3 const block_dims, - size_t const smem_size, - cudaStream_t cuda_stream, - void** kernel_params, - int32_t kernel_index) const = 0; - - - - /// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters. - /// This API is for preferred cluster launch; a preferred and a fallback cluster shapes are - /// considered for launch respectively. - virtual Status launch( - dim3 const grid_dims, - dim3 const cluster_dims, - dim3 const fallback_cluster_dims, - dim3 const block_dims, - size_t const smem_size, - cudaStream_t cuda_stream, - void** kernel_params, - int32_t kernel_index) const = 0; - - - -#if defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) - - /// Create a tensor map descriptor object representing im2col memory region. - virtual CUresult tensorMapEncodeIm2col ( - CUtensorMap* tensorMap, - CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, - void* globalAddress, - const cuuint64_t* globalDim, - const cuuint64_t* globalStrides, - const int* pixelBoxLowerCorner, - const int* pixelBoxUpperCorner, - cuuint32_t channelsPerPixel, - cuuint32_t pixelsPerColumn, - const cuuint32_t* elementStrides, - CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, - CUtensorMapL2promotion l2Promotion, - CUtensorMapFloatOOBfill oobFill) const = 0; - - /// Create a tensor map descriptor object representing tiled memory region. - virtual CUresult tensorMapEncodeTiled ( - CUtensorMap* tensorMap, - CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, - void* globalAddress, - const cuuint64_t* globalDim, - const cuuint64_t* globalStrides, - const cuuint32_t* boxDim, - const cuuint32_t* elementStrides, - CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, - CUtensorMapL2promotion l2Promotion, - CUtensorMapFloatOOBfill oobFill) const = 0; - - /// Modify an existing tensor map descriptor with an updated global address. - virtual CUresult tensorMapReplaceAddress( - CUtensorMap* tensorMap, - void* globalAddress) const = 0; - -#endif // defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) - -protected: - - /** - * Fills a buffer in Global Memory with a byte sequence copied from host memory. - * This function can be overridden to dispatch to the appropriate cuMemsetD*Async API - */ - virtual Status memsetDeviceImpl( - void* destination, ///< Device memory pointer to be filled - void const* fill_value, ///< Value to be filled in the buffer - size_t fill_size, ///< Size of the data type to be used for filling the buffer - size_t count, ///< Number of elements of size fill_size - cudaStream_t stream) const = 0; - -public: - - /// Fills a buffer in Global Memory with a byte sequence copied from host memory - template - CUTLASS_HOST_DEVICE - Status memsetDevice( - void* destination, - FillValueType fill_value, - size_t count, - cudaStream_t stream) const { - return this->memsetDeviceImpl( - destination, - &fill_value, - sizeof(FillValueType), - count, - stream); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/cutlass.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/cutlass.h deleted file mode 100644 index c68a3ba38cb554278e692d012ca2a93b547e08f1..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/cutlass.h +++ /dev/null @@ -1,165 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Basic include for CUTLASS. -*/ - -#pragma once - -#include "cutlass/detail/helper_macros.hpp" - -#if (__CUDACC_VER_MAJOR__ >= 13) - #define CUDA_STD_HEADER(header) -#else - #define CUDA_STD_HEADER(header) -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -/// Status code returned by CUTLASS operations -enum class Status { - kSuccess, ///< Operation was successful. - kErrorMisalignedOperand, ///< operands fail alignment requirements. - kErrorInvalidDataType, ///< DataType fails requirement. - kErrorInvalidLayout, ///< Layout fails alignment requirement. - kErrorInvalidProblem, ///< Specified problem size is not supported by operator. - kErrorNotSupported, ///< Operation is not supported on current device. - kErrorWorkspaceNull, ///< The given workspace is null when it is required to be non-null. - kErrorInternal, ///< An error within CUTLASS occurred. - kErrorArchMismatch, ///< CUTLASS runs on a device that it was not compiled for. - kErrorInsufficientDriver, ///< CUTLASS runs with a driver that is too old. - kErrorMemoryAllocation, ///< Kernel launch failed due to insufficient device memory. - kInvalid ///< Status is unspecified. -}; - -/// Convert cutlass status to status strings -CUTLASS_HOST_DEVICE -static char const* cutlassGetStatusString(cutlass::Status status) { - switch (status) { - case cutlass::Status::kSuccess: - return "Success"; - case cutlass::Status::kErrorMisalignedOperand: - return "Error Misaligned Operand"; - case cutlass::Status::kErrorInvalidDataType: - return "Error Invalid Data Type"; - case cutlass::Status::kErrorInvalidLayout: - return "Error Invalid Layout"; - case cutlass::Status::kErrorInvalidProblem: - return "Error Invalid Problem"; - case cutlass::Status::kErrorNotSupported: - return "Error Not Supported"; - case cutlass::Status::kErrorWorkspaceNull: - return "Error Workspace Null"; - case cutlass::Status::kErrorInternal: - return "Error Internal"; - case cutlass::Status::kErrorInsufficientDriver: - return "Error Insufficient Driver"; - case cutlass::Status::kErrorArchMismatch: - return "Error Architecture Mismatch"; - case cutlass::Status::kErrorMemoryAllocation: - return "Error Memory Allocation failed"; - case cutlass::Status::kInvalid: break; - } - - return "Invalid status"; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static const int NumThreadsPerWarp = 32; -static const int NumThreadsPerWarpGroup = 128; -static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; -static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; -static const int NumThreadsPerQuad = 4; -static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Helper function to return true when called by thread 0 of threadblock 0. -CUTLASS_HOST_DEVICE bool thread0() { - #if defined(__CUDA_ARCH__) - return (!threadIdx.x && !threadIdx.y && !threadIdx.z) && (!blockIdx.x && !blockIdx.y && !blockIdx.z); - #else - return false; - #endif -} - -/// Returns a lane index in the warp. The threads in warp may not be convergent -CUTLASS_DEVICE -int canonical_lane_idx() { - #if defined(__CUDA_ARCH__) - return threadIdx.x % NumThreadsPerWarp; - #else - return 0; - #endif -} - -/// Returns a warp-uniform value indicating the canonical warp index of the calling threads. -/// Threads within the warp must be converged. -CUTLASS_DEVICE -int canonical_warp_idx_sync() { - #if defined(__CUDA_ARCH__) - return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarp, 0); - #else - return 0; - #endif -} - -/// Returns a warp index in the CTA. The threads in warp may not be convergent -/// As it doesn't sync the warp, it faster and allows forward progress -CUTLASS_DEVICE -int canonical_warp_idx() { - #if defined(__CUDA_ARCH__) - return threadIdx.x / NumThreadsPerWarp; - #else - return 0; - #endif -} - -/// Returns a warp-uniform value indicating the canonical warp group index of the calling threads. -/// Threads within the warp must be converged. -CUTLASS_DEVICE -int canonical_warp_group_idx() { - #if defined(__CUDA_ARCH__) - return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); - #else - return 0; - #endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/blockwise_scale_layout.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/blockwise_scale_layout.hpp deleted file mode 100644 index a304cd6e3adae2c009c5b474a8e2920b618a3ea3..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/blockwise_scale_layout.hpp +++ /dev/null @@ -1,305 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - - -/*! \file - \brief Blockwise Scale configs specific for Blockwise/Groupwise MMA -*/ - -#pragma once - -#include "cutlass/layout/matrix.h" - -#include "cute/int_tuple.hpp" -#include "cute/atom/mma_traits_sm100.hpp" -#include "cute/arch/mma_sm90.hpp" - -namespace cutlass::detail{ - -///////////////////////////////////////////////////////////////////////////////////////////////// -using namespace cute; - -template -struct Sm1xxBlockwiseScaleConfig { - - using ShapeSFA = Shape, int32_t>, Shape, int32_t>, int32_t>; - using ShapeSFB = Shape, int32_t>, Shape, int32_t>, int32_t>; - - using StrideSFA = conditional_t,Stride<_0,int32_t>, int32_t>, - Stride,Stride<_0,_1>, int32_t>>; - - using StrideSFB = conditional_t,Stride<_0,int32_t>, int32_t>, - Stride,Stride<_0,_1>, int32_t>>; - - using LayoutSFA = Layout; - using LayoutSFB = Layout; - - CUTE_HOST_DEVICE - static constexpr auto - deduce_layoutSFA() { - return LayoutSFA{}; - } - - template - CUTE_HOST_DEVICE - static constexpr auto - smem_atom_layoutSFA(CtaShape_MNK cta_shape_mnk) { - static_assert(cute::is_static_v, "Expect static CTA shape"); - auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - auto [M, N, K] = cta_shape_mnk; - if constexpr (majorSFA == UMMA::Major::MN) { - return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeM)>{})); - } - else { - return make_stride(make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeK)>{}), make_stride(_0{}, _1{})); - } - }(); - - auto [M, N, K] = cta_shape_mnk; - return make_layout( - make_shape(make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeM)>{}), - make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeK)>{})), - strides - ); - } - - - CUTE_HOST_DEVICE - static constexpr auto - deduce_layoutSFB() { - return LayoutSFB{}; - } - - template - CUTE_HOST_DEVICE - static constexpr auto - smem_atom_layoutSFB(CtaShape_MNK cta_shape_mnk) { - static_assert(cute::is_static_v, "Expect static CTA shape"); - auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (majorSFA == UMMA::Major::MN) { - return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeN)>{})); - } - else { - return make_stride(make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeK)>{}), make_stride(_0{}, _1{})); - } - }(); - - auto [M, N, K] = cta_shape_mnk; - return make_layout( - make_shape(make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeN)>{}), - make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeK)>{})), - strides - ); - } - - // The following function is provided for user fill dynamic problem size to the layout_SFA. - template - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape_SFA(ProblemShape problem_shape) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - auto [M, N, K, L] = problem_shape_MNKL; - if constexpr (majorSFA == UMMA::Major::MN) { - return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(M, SFVecSizeM))); - } - else { - return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{})); - } - }(); - - auto [M, N, K, L] = problem_shape_MNKL; - auto mk_layout = make_layout( - make_shape(make_shape(Int{}, cute::ceil_div(M, SFVecSizeM)), - make_shape(Int{}, cute::ceil_div(K, SFVecSizeK))), - strides - ); - - return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); - } - - // The following function is provided for user fill dynamic problem size to the layout_SFB. - template - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape_SFB(ProblemShape problem_shape) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - auto [M, N, K, L] = problem_shape_MNKL; - - if constexpr (majorSFB == UMMA::Major::MN) { - return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(N, SFVecSizeN))); - } - else { - return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{})); - } - }(); - - auto [M, N, K, L] = problem_shape_MNKL; - auto nk_layout = make_layout( - make_shape(make_shape(Int{}, cute::ceil_div(N, SFVecSizeN)), - make_shape(Int{}, cute::ceil_div(K, SFVecSizeK))), - strides - ); - - return make_layout(append(shape(nk_layout), L), append(stride(nk_layout), size(filter_zeros(nk_layout)))); - } - -}; - -template -struct RuntimeBlockwiseScaleConfig { - - using ShapeSFA = Shape, Shape, int32_t>; - using ShapeSFB = Shape, Shape, int32_t>; - - using StrideSFA = conditional_t,Stride<_0,int32_t>, int32_t>, - Stride,Stride<_0,_1>, int32_t>>; - - using StrideSFB = conditional_t,Stride<_0,int32_t>, int32_t>, - Stride,Stride<_0,_1>, int32_t>>; - - using LayoutSFA = Layout; - using LayoutSFB = Layout; - - CUTE_HOST_DEVICE - static constexpr auto - deduce_layoutSFA() { - return LayoutSFA{}; - } - - CUTE_HOST_DEVICE - static constexpr auto - deduce_layoutSFB() { - return LayoutSFB{}; - } - - // The following function is provided for user fill dynamic problem size to the layout_SFA. - template - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape_SFA(ProblemShape problem_shape, SFVecShape sf_vec_shape) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - auto [M, N, K, L] = problem_shape_MNKL; - auto [sfm, sfn, sfk] = sf_vec_shape; - if constexpr (majorSFA == UMMA::Major::MN) { - return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(M, sfm))); - } - else { - return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); - } - }(); - - auto [M, N, K, L] = problem_shape_MNKL; - auto [sfm, sfn, sfk] = sf_vec_shape; - auto mk_layout = make_layout( - make_shape(make_shape(sfm, cute::ceil_div(M, sfm)), - make_shape(sfk, cute::ceil_div(K, sfk))), - strides - ); - - return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); - } - - // The following function is provided for user fill dynamic problem size to the layout_SFB. - template - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape_SFB(ProblemShape problem_shape, SFVecShape sf_vec_shape) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - auto [M, N, K, L] = problem_shape_MNKL; - auto [sfm, sfn, sfk] = sf_vec_shape; - - if constexpr (majorSFB == UMMA::Major::MN) { - return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(N, sfn))); - } - else { - return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); - } - }(); - - auto [M, N, K, L] = problem_shape_MNKL; - auto [sfm, sfn, sfk] = sf_vec_shape; - auto nk_layout = make_layout( - make_shape(make_shape(sfn, cute::ceil_div(N, sfn)), - make_shape(sfk, cute::ceil_div(K, sfk))), - strides - ); - - return make_layout(append(shape(nk_layout), L), append(stride(nk_layout), size(filter_zeros(nk_layout)))); - } - -}; - -// Sm90 only supports MN major for SFA and SFB for now -template -using Sm90BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig< - SFVecSizeM, - SFVecSizeN, - SFVecSizeK, - majorSFA == cute::GMMA::Major::MN ? UMMA::Major::MN : UMMA::Major::K, - majorSFB == cute::GMMA::Major::MN ? UMMA::Major::MN : UMMA::Major::K>; - -template -using Sm100BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig; - -template -using Sm120BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig; - -template -constexpr auto sm90_trivial_blockwise_scale_config(MmaTileShape_MNK) { - return Sm90BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; -} - -template -constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) { - return Sm100BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; -} - -template -constexpr auto sm120_trivial_blockwise_scale_config(MmaTileShape_MNK) { - return Sm120BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/cluster.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/cluster.hpp deleted file mode 100644 index d35765adebaa35bfcd767ff245ec72d453c28563..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/cluster.hpp +++ /dev/null @@ -1,99 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - - - -#include "cute/container/tuple.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/trace.h" -#include "cute/layout.hpp" // cute::make_shape -#include "cutlass/trace.h" // CUTLASS_TRACE_HOST - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::detail { - -// Returns either ClusterShape, if it is static, or a Shape> populated with the -// x and y dimensions of `dynamic_cluster_shape`. -template -CUTLASS_HOST_DEVICE -static auto -select_cluster_shape(ClusterShape cluster_shape, dim3 dynamic_cluster_shape) { - return cute::conditional_return>( - make_shape(static_cast(dynamic_cluster_shape.x), static_cast(dynamic_cluster_shape.y), cute::Int<1>{}), - cluster_shape); -} - -template -CUTLASS_DEVICE -static auto -select_cluster_shape(ClusterShape cluster_shape) { - if constexpr (cute::is_static_v) { - return cluster_shape; - } - else { - dim3 dynamic_cluster_shape = cute::cluster_shape(); - return make_shape(static_cast(dynamic_cluster_shape.x), static_cast(dynamic_cluster_shape.y), cute::Int<1>{}); - } -} - -// Dynamic cluster shape can_implement rule -template -CUTLASS_HOST_DEVICE -bool -preferred_cluster_can_implement(dim3 cluster_shape, dim3 cluster_shape_fallback) { - bool implementable{true}; - - // Runtime cluster shape should satisfy MMA requirements - auto AtomThrShapeM = cute::size<0>(AtomThrShapeMNK{}); - implementable &= (cluster_shape.x > 0 && cluster_shape.y > 0 && cluster_shape.z > 0); - implementable &= (cluster_shape.x % AtomThrShapeM == 0); - - implementable &= (cluster_shape_fallback.x > 0 && cluster_shape_fallback.y > 0 && cluster_shape_fallback.z > 0); - implementable &= (cluster_shape_fallback.x % AtomThrShapeM == 0); - - // Only support pow2 runtime cluster shape for now - implementable &= ispow2(cluster_shape.x) && - ispow2(cluster_shape.y) && - ispow2(cluster_shape.z); - - implementable &= ispow2(cluster_shape_fallback.x) && - ispow2(cluster_shape_fallback.y) && - ispow2(cluster_shape_fallback.z); - - return implementable; -} - -} // namespace cutlass::detail - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective.hpp deleted file mode 100644 index 01085c54159fc1cd5d6b7e2ee1d40a46cccd4f67..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective.hpp +++ /dev/null @@ -1,191 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cute/container/tuple.hpp" -#include "cute/layout.hpp" // cute::size(shape) -#include "cute/arch/mma_sm100_desc.hpp" // cute::UMMA::MXF4Format, cute::UMMA::MXF8F6F4Format -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -struct deduce_mixed_width_dtype { -static_assert(I >= 0u && I <= 2u, "Valid indices are 0, 1, and 2, which represent Operand, Scale, and Bias, respectively."); - -private: - using underlying_tuple = cute::conditional_t::value, Tuple, cute::tuple>; - static constexpr size_t valid_index = cute::min(I, cute::tuple_size_v - 1); - -public: - using type = cute::conditional_t<(I < cute::tuple_size_v), - cute::tuple_element_t, - void>; -}; - -template -using deduce_mixed_width_dtype_t = typename deduce_mixed_width_dtype::type; - - - -template -CUTLASS_HOST_DEVICE -static constexpr bool -is_sm10x_runtime_f8f6f4() { - return (cute::is_same_v || - cute::is_same_v || - cute::is_same_v); -} - -template -CUTLASS_HOST_DEVICE -static constexpr bool -is_sm10x_f8f6f4_inputs() { - return ( - - cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - - cute::is_same_v || - cute::is_same_v - || cute::is_same_v || - cute::is_same_v || - cute::is_same_v - - ) && - ( - - cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - - cute::is_same_v || - cute::is_same_v - || cute::is_same_v || - cute::is_same_v || - cute::is_same_v - - ); -} - -template -CUTLASS_HOST_DEVICE -static constexpr bool -is_sm100_mma_f8f6f4() { - return (cute::size<2>(typename TiledMma::Shape_MNK{}) == 32) && is_sm10x_f8f6f4_inputs(); -} - -template -CUTLASS_HOST_DEVICE -static constexpr bool -is_sm10x_f8f6f4_element() { - return (cute::is_same_v - || cute::is_same_v - || cute::is_same_v - || cute::is_same_v - || cute::is_same_v - - ); -} - - -template -CUTLASS_HOST_DEVICE -static constexpr bool -is_sm10x_f4_element() { - return (cute::is_same_v - ); -} - -template -CUTLASS_HOST_DEVICE -static constexpr bool -is_sm10x_mxf8f6f4_input() { - // ElementType must be F8, F6, or F4 - return ( cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v); -} - -template -CUTLASS_HOST_DEVICE -static constexpr bool -is_sm10x_mxf4nvf4_input() { - // ElementType must be F4 - return ( cute::is_same_v || - cute::is_same_v - ); -} - -template -struct sm10x_block_scale_runtime_input_t { - static constexpr bool IsF8F6F4MmaInput = is_sm10x_mxf8f6f4_input(); - static constexpr bool IsF4MmaInput = is_sm10x_mxf4nvf4_input(); - - using Type = cute::conditional_t - >; -}; - - -template -CUTLASS_HOST_DEVICE -static constexpr bool -is_sm120_f8f6f4() { - return (cute::size<2>(typename TiledMma::Shape_MNK{}) == 32) && is_sm10x_f8f6f4_inputs(); -} - -template -CUTLASS_HOST_DEVICE -static constexpr bool -is_sm100_sparse_f8f6f4() { - return (cute::size<2>(typename TiledMma::Shape_MNK{}) == 64) && is_sm10x_f8f6f4_inputs(); -} - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/mixed_input_utils.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/mixed_input_utils.hpp deleted file mode 100644 index 89d250001eaa00990319cc2a1da35ec0dccb8703..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/mixed_input_utils.hpp +++ /dev/null @@ -1,1249 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_conversion.h" - -#include "cute/util/type_traits.hpp" -#include "cute/arch/copy_sm90.hpp" -#include "cute/numeric/arithmetic_tuple.hpp" - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -// The universal converter -template < - class SrcType, - class DstType, - class LayoutIn, - class LayoutOut -> -struct LayoutAwareConvertImpl { - template - CUTLASS_DEVICE - static void convert( - cute::Tensor const& src, - cute::Tensor & dst) { - - static_assert(cute::is_same_v && - cute::is_same_v); - static_assert(cute::cosize_v == cute::cosize_v); - constexpr int N = decltype(cute::max_common_vector(LayoutIn{}, LayoutOut{})){}; - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - using Converter = cutlass::NumericArrayConverter; - auto&& src_vm = cute::recast(src); - auto&& dst_vm = cute::recast(dst); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < src_vm.size(); ++i) { - dst_vm(i) = Converter::convert(src_vm(i)); - } - } -}; - -// Specialization for INT4 -> BF16 with [02461357] value order -template <> -struct LayoutAwareConvertImpl< - cutlass::int4b_t, - cutlass::bfloat16_t, - cute::Layout, cute::Stride<_4,_1>>, - cute::Layout<_8> -> { - template - CUTLASS_DEVICE - static void convert( - cute::Tensor, cute::Stride<_4,_1>> - > const& src, - cute::Tensor - >& dst) { - - static_assert(cute::is_same_v && - cute::is_same_v); - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - using RegArray = cutlass::AlignedArray; - - auto&& src_reg = cute::recast(src)(0); - auto&& r = cute::recast(dst)(0); - CUTLASS_PRAGMA_UNROLL - for (size_t ii = 0; ii < RegArray::kElements; ++ii) { - r[ii] = src_reg >> (4 * (ii)); - static constexpr uint32_t xor_mask = 0x43084308; - static constexpr uint32_t lo_mask = 0x000F000F; - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii]) - : "n"(lo_mask), "n"(xor_mask), "n"(immLut)); - static constexpr uint32_t lo_bias = xor_mask; // 0x43084308, {136, 136} - { - __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); - bf16x2_val = __hsub2(bf16x2_val, - reinterpret_cast(lo_bias)); - } - } - } -}; - -// Specialization for UINT4 -> BF16 with [02461357] value order -template <> -struct LayoutAwareConvertImpl< - cutlass::uint4b_t, - cutlass::bfloat16_t, - cute::Layout, cute::Stride<_4,_1>>, - cute::Layout<_8> -> { - template - CUTLASS_DEVICE - static void convert( - cute::Tensor, cute::Stride<_4,_1>> - > const& src, - cute::Tensor - >& dst) { - - static_assert(cute::is_same_v && - cute::is_same_v); - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - using RegArray = cutlass::AlignedArray; - - auto&& src_reg = cute::recast(src)(0); - auto&& r = cute::recast(dst)(0); - CUTLASS_PRAGMA_UNROLL - for (size_t ii = 0; ii < RegArray::kElements; ++ii) { - r[ii] = src_reg >> (4 * (ii)); - static constexpr uint32_t or_mask = 0x43004300; - static constexpr uint32_t lo_mask = 0x000F000F; - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii]) - : "n"(lo_mask), "n"(or_mask), "n"(immLut)); - static constexpr uint32_t lo_bias = or_mask; // 0x43004300, {128, 128} - { - __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); - bf16x2_val = __hsub2(bf16x2_val, - reinterpret_cast(lo_bias)); - } - } - } -}; - -// Specialization for INT4 -> FP16 with [02461357] value order -template <> -struct LayoutAwareConvertImpl< - cutlass::int4b_t, - cutlass::half_t, - cute::Layout, cute::Stride<_4,_1>>, - cute::Layout<_8> -> { - template - CUTLASS_DEVICE - static void convert( - cute::Tensor, cute::Stride<_4,_1>> - > const& src, - cute::Tensor - >& dst) { - - static_assert(cute::is_same_v && - cute::is_same_v); - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - using RegArray = cutlass::AlignedArray; - - auto&& src_reg = cute::recast(src)(0); - auto&& r = cute::recast(dst)(0); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ii += 2) { - auto src_ = src_reg >> (4 * (ii)); - r[ii + 0] = src_; - r[ii + 1] = src_; - static constexpr uint32_t lo_xor_mask = 0x64086408; - static constexpr uint32_t hi_xor_mask = 0x64806480; - static constexpr uint32_t lo_mask = 0x000F000F; - static constexpr uint32_t hi_mask = 0x00F000F0; - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii + 0]) - : "n"(lo_mask), "n"(lo_xor_mask), "n"(immLut)); - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii + 1]) - : "n"(hi_mask), "n"(hi_xor_mask), "n"(immLut)); - static constexpr uint32_t lo_bias = 0x64086408; // {1032, 1032} - static constexpr uint32_t hi_bias = 0xD480D480; // {-72, -72} - static constexpr uint32_t hi_scale = 0x2C002C00; // {1/16, 1/16} - { - half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); - fp16x2_val = __hsub2(fp16x2_val, - reinterpret_cast(lo_bias)); - } - { - half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); - fp16x2_val = __hfma2(fp16x2_val, - reinterpret_cast(hi_scale), - reinterpret_cast(hi_bias)); - } - } - } -}; - -// Specialization for UINT4 -> FP16 with [02461357] value order -template <> -struct LayoutAwareConvertImpl< - cutlass::uint4b_t, - cutlass::half_t, - cute::Layout, cute::Stride<_4,_1>>, - cute::Layout<_8> -> { - template - CUTLASS_DEVICE - static void convert( - cute::Tensor, cute::Stride<_4,_1>> - > const& src, - cute::Tensor - >& dst) { - - static_assert(cute::is_same_v && - cute::is_same_v); - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - using RegArray = cutlass::AlignedArray; - - auto&& src_reg = cute::recast(src)(0); - auto&& r = cute::recast(dst)(0); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ii += 2) { - auto src_ = src_reg >> (4 * (ii)); - r[ii + 0] = src_; - r[ii + 1] = src_; - static constexpr uint32_t or_mask = 0x64006400; - static constexpr uint32_t lo_mask = 0x000F000F; - static constexpr uint32_t hi_mask = 0x00F000F0; - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii]) - : "n"(lo_mask), "n"(or_mask), "n"(immLut)); - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii + 1]) - : "n"(hi_mask), "n"(or_mask), "n"(immLut)); - static constexpr uint32_t lo_bias = or_mask; // 0x64006400, {1024, 1024} - static constexpr uint32_t hi_bias = 0xD400D400; // {-64, -64} - static constexpr uint32_t hi_scale = 0x2C002C00; // {1/16, 1/16} - { - half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); - fp16x2_val = __hsub2(fp16x2_val, - reinterpret_cast(lo_bias)); - } - { - half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); - fp16x2_val = __hfma2(fp16x2_val, - reinterpret_cast(hi_scale), - reinterpret_cast(hi_bias)); - } - } - } -}; -/* -// Specialization for E5M2 -> FP16 with [3120] value order -template <> -struct LayoutAwareConvertImpl< - cutlass::float_e5m2_t, - cutlass::half_t, - cute::Layout, cute::Stride<_2,_1>>, - cute::Layout<_4> -> { - template - CUTLASS_DEVICE - static void convert( - cute::Tensor, cute::Stride<_2,_1>> - > const& src, - cute::Tensor - >& dst) { - - static_assert(cute::is_same_v && - cute::is_same_v); - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - using RegArray = cutlass::AlignedArray; - - auto&& src_reg = cute::recast(src)(0); - auto&& r = cute::recast(dst)(0); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - // in registers: a3, a1, a2, a0 - r[RegArray::kElements - ii - 1] = src_reg << (8 * (ii)); - - static constexpr uint32_t and_mask = 0xFF00FF00; - asm volatile( - "{\n" - " and.b32 %0, %0, %1;\n" - "}\n" - : "+r"(r[ii]) - : "n"(and_mask)); - } - } -}; -*/ -// Specialization for INT8 -> BF16 with [3120] value order -template <> -struct LayoutAwareConvertImpl< - cutlass::int8_t, - cutlass::bfloat16_t, - cute::Layout, cute::Stride<_2,_1>>, - cute::Layout<_4> -> { - template - CUTLASS_DEVICE - static void convert( - cute::Tensor, cute::Stride<_2,_1>> - > const& src, - cute::Tensor - >& dst) { - - static_assert(cute::is_same_v && - cute::is_same_v); - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - using RegArray = cutlass::AlignedArray; - - auto&& src_reg = cute::recast(src)(0); - auto&& r = cute::recast(dst)(0); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - uint32_t tmp0, tmp1; - r[ii] = src_reg >> (8 * (ii)); - static constexpr uint32_t or_mask = 0x43004300; - static constexpr uint32_t and_mask_0 = 0x007F007F; - static constexpr uint32_t and_mask_1 = 0x00800080; - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - asm volatile( - "{\n" - " lop3.b32 %0, %1, %2, %3, %4;\n" - "}\n" - : "=r"(tmp0) - : "r"(r[ii]), "n"(and_mask_0), "n"(or_mask), "n"(immLut)); - asm volatile( - "{\n" - " lop3.b32 %0, %1, %2, %3, %4;\n" - "}\n" - : "=r"(tmp1) - : "r"(r[ii]), "n"(and_mask_1), "n"(or_mask), "n"(immLut)); - { - __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); - bf16x2_val = __hsub2(reinterpret_cast<__nv_bfloat162 const&>(tmp0), - reinterpret_cast<__nv_bfloat162 const&>(tmp1)); - } - } - } -}; - -// Specialization for INT8 -> FP16 with [3120] value order -template <> -struct LayoutAwareConvertImpl< - cutlass::int8_t, - cutlass::half_t, - cute::Layout, cute::Stride<_2,_1>>, - cute::Layout<_4> -> { - template - CUTLASS_DEVICE - static void convert( - cute::Tensor, cute::Stride<_2,_1>> - > const& src, - cute::Tensor - >& dst) { - - static_assert(cute::is_same_v && - cute::is_same_v); - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - using RegArray = cutlass::AlignedArray; - - auto&& src_reg = cute::recast(src)(0); - auto&& r = cute::recast(dst)(0); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - r[ii] = src_reg >> (8 * (ii)); - static constexpr uint32_t xor_mask = 0x64806480; - static constexpr uint32_t and_mask = 0x00FF00FF; - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); - { - static constexpr uint32_t bias = 0x64806480; - __half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); - fp16x2_val = __hsub2(fp16x2_val, - reinterpret_cast<__half2 const&>(bias)); - } - } - } -}; - -template < - class EngineIn, - class EngineOut, - class LayoutIn, - class LayoutOut -> -CUTLASS_DEVICE -void LayoutAwareConvert( // Accept mutable temporaries - cute::Tensor const& src, - cute::Tensor && dst) { - - LayoutAwareConvert(src, dst); -} -template < - class EngineIn, - class EngineOut, - class LayoutIn, - class LayoutOut -> -CUTLASS_DEVICE -void LayoutAwareConvert( - cute::Tensor const& src, - cute::Tensor & dst) { - - using SrcType = typename EngineIn::value_type; - using DstType = typename EngineOut::value_type; - Tensor src_vm = coalesce(src); - Tensor dst_vm = coalesce(dst); - Layout src_layout = src_vm.layout(); - Layout dst_layout = dst_vm.layout(); - LayoutAwareConvertImpl::convert(src_vm, dst_vm); -} - - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - namespace detail { - enum class ConversionMode { - DirectConvert, // A * B - ConvertAndScale, // (scale * A) * B - ConvertAndScaleWithZero // (scale * A + zeros) * B - }; - } // namespace detail -} //namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective::detail { - -template -static constexpr -CUTLASS_HOST_DEVICE -auto get_logical_ptr(PointerType const* ptr) { - return cute::recast_ptr(ptr); -} -template -static constexpr -CUTLASS_HOST_DEVICE -auto get_smem_layout(LayoutAtom layout_atom, TileShape const& tile_shape, Stride const& stride) { - if constexpr (not cute::is_layout::value) { - return tile_to_shape( - layout_atom, - append(tile_shape, Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,Stride>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}); - } - else { - auto gmem_tile = composition(stride, tile_shape); - return make_layout_like(append(gmem_tile, make_layout(Int{}, 0))); - } -} -template -static constexpr -CUTLASS_HOST_DEVICE -auto get_gmem_layout(Shape const& shape, Stride const& stride) { - if constexpr (not cute::is_layout::value) { - return make_layout(shape, stride); - } - else { - return stride; - } -} - -template -struct MixedInputUtils { -private: - using ConversionMode = cutlass::detail::ConversionMode; - using KernelSchedule = typename Collective::KernelSchedule; - using SmemLayoutA = typename Collective::SmemLayoutA; - using SmemLayoutB = typename Collective::SmemLayoutB; - using SmemLayoutScale = typename Collective::SmemLayoutScale; - using SwappedElementA = typename Collective::SwappedElementA; - using SwappedElementB = typename Collective::SwappedElementB; - using RealSwappedElementA = typename Collective::RealSwappedElementA; - using RealSwappedElementB = typename Collective::RealSwappedElementB; - using ElementScale = typename Collective::ElementScale; - using ElementZero = typename Collective::ElementZero; - using SmemCopyAtomScale = typename Collective::SmemCopyAtomScale; - static constexpr auto KernelConversionMode = Collective::KernelConversionMode; - static constexpr auto ModeHasScales = Collective::ModeHasScales; - static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable; - -public: - static constexpr auto - elements_per_smem_scale() { - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return 0; - } - else if constexpr (ModeHasScales) { - return cute::cosize_v; - } - else { - static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); - } - } - - static constexpr auto - elements_per_smem_zero() { - if constexpr (KernelConversionMode == ConversionMode::DirectConvert || - KernelConversionMode == ConversionMode::ConvertAndScale ) { - return 0; - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - return cute::cosize_v; - } - else { - static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); - } - } - - // These methods use some the public members of the class. For that reason, we define them after the public section. - static constexpr uint32_t - compute_tma_transaction_bytes_mk() { - return cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); - } - - static constexpr uint32_t - compute_tma_transaction_bytes_nk() { - return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); - } - - static constexpr uint32_t - compute_tma_transaction_bytes_extra() { - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return 0; - } - else if constexpr (ModeHasScales) { - constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); - static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - return scale_tx_bytes; - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - // Scale and zero share smem layout - constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); - static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA - return scale_tx_bytes + zero_tx_bytes; - } - else { - static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); - } - } - - static constexpr uint32_t - compute_tma_transaction_bytes_extra_transform() { - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return 0; - } - else if constexpr (ModeHasScales) { - constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(filter_zeros(SmemLayoutScale{})) * size<1>(filter_zeros(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); - static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - return scale_tx_bytes; - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - // Scale and zero share smem layout - constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(filter_zeros(SmemLayoutScale{})) * size<1>(filter_zeros(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); - static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA - return scale_tx_bytes + zero_tx_bytes; - } - else { - static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); - } - } - - /// Utilities to copy A and extra inputs from smem to RF - template - CUTLASS_DEVICE - static void copy_tensors_MK( - SmemTiledCopyA const& smem_tiled_copy_A, - TensorASmemView const& tCsA, - TensorACopyView& tCrA_copy_view, - cute::tuple const& partitioned_mma_extra_info, - cute::tuple const& tiled_copy_and_views, - int k_block, - int read_stage) { - - copy(smem_tiled_copy_A, tCsA(_,_,k_block,read_stage), tCrA_copy_view(_,_,k_block)); - - if (k_block == 0) { - // We are starting a new k-tile so copy the scale - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // nothing to do - } - else if constexpr (ModeHasScales) { - auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); - auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); - auto tCsS = cute::get<0>(partitioned_mma_extra_info); - copy(smem_tiled_copy_S, tCsS(_,_,k_block,read_stage), tCrS_copy_view(_,_,k_block)); - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - // Nothing extra to do - } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - auto tCsZ = cute::get<2>(partitioned_mma_extra_info); - auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); - copy(smem_tiled_copy_S, tCsZ(_,_,k_block,read_stage), tCrZ_copy_view(_,_,k_block)); - } else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); - } - } - } - - /// (Designed for separate transform pipeline in Blackwell) - /// Utilities to copy extra inputs from smem to RF - template - CUTLASS_DEVICE - static void copy_scale_zeros_for_transform( - cute::tuple & partitioned_transform_extra_info, - int load2transform_consumer_index) { - - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // nothing to do - } - else if constexpr (ModeHasScales) { - auto smem_tiled_copy_S = cute::get<0>(partitioned_transform_extra_info); - auto&& scales = cute::get<1>(partitioned_transform_extra_info); - using ScaleType = decltype(scales); - auto tSrS = make_tensor(scales.data(), scales.layout()); - auto tSsS = cute::get<2>(partitioned_transform_extra_info); - copy(smem_tiled_copy_S, tSsS(_,_,_,_,load2transform_consumer_index), tSrS); - - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - // Nothing extra to do - } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - auto&& zeros = cute::get<3>(partitioned_transform_extra_info); - using ZeroType = decltype(zeros); - auto tZrZ = make_tensor(zeros.data(), zeros.layout()); - auto tZsZ = cute::get<4>(partitioned_transform_extra_info); - copy(smem_tiled_copy_S, tZsZ(_,_,_,_,load2transform_consumer_index), tZrZ); - - } else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); - } - } - - // Helper functions to select packing for conversion - template - struct select_packing { // Naive packing policy - static constexpr auto value() { - return Int, sizeof_bits_v))>{}; - } - }; - - // The core converter uses a lookup table to converts i4 -> 8 bit value. - template - CUTLASS_DEVICE - static void lookup_table_convert( // Accept mutable temporaries - Tensor const& src, - Tensor && dst, - Tensor const& scales_neg, - Tensor const& scales_pos) { - - lookup_table_convert(src, dst, scales_neg, scales_pos); - } - template - CUTLASS_DEVICE - static void lookup_table_convert( - Tensor const& src, - Tensor & dst, - Tensor const& scales_neg, - Tensor const& scales_pos) { - - constexpr int N = cute::cosize(LayoutIn{}); - static_assert(N == 4 || N == 8); - static_assert(cosize(LayoutScale{}) <= N / 4, - "at least 4 consecutive weights must share the same scale."); - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - using RegArray = cutlass::AlignedArray; - - // View the input as reg - auto&& src_reg = cute::recast(src)(0); - auto&& r = cute::recast(dst)(0); - - // Determines if to get from the signed or unsigned candidates - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1 - asm volatile( - "{\n" - " lop3.b32 %0, %1, %2, %3, %4;\n" \ - "}\n" - : "=r"(sign) - : "r"(src_reg), "n"(0x88888888), "n"(0x64206420), "n"(immLut) - ); - sign = sign >> 1; - - // Ignore sign bit when indexing into LUT - uint32_t lut_idx = src_reg & 0x77777777; - Tensor scales_neg_ = cute::filter(scales_neg); - Tensor scales_pos_ = cute::filter(scales_pos); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 4; ++i, lut_idx >>=16, sign >>=16) { - auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_(i)); - auto&& scale_pos_ = reinterpret_cast const&>(scales_pos_(i)); - asm volatile( - "{\n" - " .reg .b32 pos, neg ;\n" \ - " prmt .b32 neg, %3, %4, %1 ;\n" \ - " prmt .b32 pos, %5, %6, %1 ;\n" \ - " prmt .b32 %0, pos, neg, %2 ;\n" \ - "}\n" - : "=r"(r[i]) - : "r"(lut_idx), "r"(sign), "r"(scale_neg_[0]), "r"(scale_neg_[1]), "r"(scale_pos_[0]), "r"(scale_pos_[1]) - ); - } - } - - /// Utilities to dequantize A. - template - CUTLASS_DEVICE - static void static_check_scale(Layout const& tensor) { - static_assert(shape<0>(Layout{}) >= 4 && stride<0>(Layout{}) == 0, "At least 4 adjacent weights in a thread must share the same scale."); - } - template - CUTLASS_DEVICE - static void static_check_scale(Tensor const& tensor) { - static_check_scale(flatten(Layout{})); - } - template - CUTLASS_DEVICE - static void dequantize_A_kblock( - Tensor const& tCrA_load, - Tensor& tCrA_mma, - cute::tuple& partitioned_extra_info, - int const k_block) { - - static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); - static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); - static_assert(cosize_v == cosize_v); - static_assert(size_v == cosize_v); - static_assert(size_v == cosize_v); - using SrcType = typename EngineIn::value_type; - using DstType = typename EngineOut::value_type; - - Tensor src = tCrA_load(_, _, k_block); - Tensor dst = tCrA_mma(_, _, k_block); - - CUTE_STATIC_ASSERT_V(size(src(_, 0)) == cosize(src(_, 0).layout()), - "The first mode of tensor src must be contiguous in memory"); - // try to make the size of the first mode equal to 32bit - int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, - ceil_div(32, sizeof_bits_v)); - Tensor src_vm = cute::group_modes<1,-1>(cute::zipped_divide(src, Int{})); - Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, Int{})); - - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<1>(dst_vm); ++i) { - LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); - } - } - else if constexpr (UseScaleLookupTable) { - constexpr int num_elements = decltype(size(src))::value; - static_assert(is_same_v, "Lookup table only supports int4 being the quant type now."); - static_assert(sizeof_bits_v == 64, "Lookup table only supports 8 8bit scale values now."); - static_assert(num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting."); - - Tensor tCrS_neg = cute::get<1>(partitioned_extra_info); - auto&& tCrS_pos = cute::get<2>(partitioned_extra_info); // modification to its value is needed - Tensor scales_neg = tCrS_neg(_, _, k_block); - Tensor scales_pos = tCrS_pos(_, _, k_block); - CUTE_STATIC_ASSERT_V(cute::size(src) == cute::size(scales_neg)); - - static_check_scale(scales_neg); - static_check_scale(scales_pos); - Tensor scales_neg_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales_neg, Int{})); - Tensor scales_pos_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales_pos, Int{})); - - if (k_block == 0) { - Tensor scales_neg_vm_ = filter(scales_neg_vm); - Tensor scales_pos_vm_ = filter(scales_pos_vm); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(scales_neg_vm_.layout()); ++i) - { - auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_vm_(i)); - auto&& scale_pos_ = reinterpret_cast &>(scales_pos_vm_(i)); - constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - asm volatile( - "{\n" - " lop3 .b32 %0, %2, %4, %5, %6;\n" \ - " xor .b32 %1, %3, %5; \n" \ - "}\n" - : "=r"(scale_pos_[0]), "=r"(scale_pos_[1]) - : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut) - ); - } - } - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<1>(dst_vm); ++i) { - lookup_table_convert(src_vm(_, i), dst_vm(_, i), scales_neg_vm(_, i), scales_pos_vm(_, i)); - } - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block); - CUTE_STATIC_ASSERT_V(size(src) == size(scales)); - Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, Int{})); - - if constexpr (is_same_v) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<1>(dst_vm); ++i) { - LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<0>(dst_vm); ++j) { - dst_vm(j, i) *= scales_vm(j, i); - } - } - } - else { - auto stage = make_tensor_like(src_vm(_, 0)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<1>(dst_vm); ++i) { - LayoutAwareConvert(src_vm(_, i), stage); - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<0>(dst_vm); ++j) { - stage(j) *= scales_vm(j, i); - } - LayoutAwareConvert(stage, dst_vm(_, i)); - } - } - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - static_assert(is_same_v, "ElementScale and ElementZero must be the same."); - Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block); - Tensor zeros = cute::get<3>(partitioned_extra_info)(_, _, k_block); - CUTE_STATIC_ASSERT_V(size(src) == size(scales)); - CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); - Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, Int{})); - Tensor zeros_vm = cute::group_modes<1,-1>(cute::zipped_divide(zeros, Int{})); - - if constexpr (is_same_v) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<1>(dst_vm); ++i) { - LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<0>(dst_vm); ++j) { - dst_vm(j, i) = dst_vm(j, i) * scales_vm(j, i) + zeros_vm(j, i); - } - } - } - else { - auto stage = make_tensor_like(src_vm(_, 0)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<1>(dst_vm); ++i) { - LayoutAwareConvert(src_vm(_, i), stage); - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<0>(dst_vm); ++j) { - stage(j) = stage(j) * scales_vm(j, i) + zeros_vm(j, i); - } - LayoutAwareConvert(stage, dst_vm(_, i)); - } - } - } - else { - static_assert(cutlass::detail::dependent_false, "No A data is loaded."); - } - } - - /// (Designed for separate transform pipeline in Blackwell) - /// Utilities to dequantize A. - template - CUTLASS_DEVICE - static void dequantize_A_kblock_for_transform( - Tensor const& tArA, - Tensor& tArACompute, - cute::tuple const& partitioned_extra_info, - int const k_block) { - - static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); - static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); - static_assert(cosize_v == cosize_v); - static_assert(size_v == cosize_v); - static_assert(size_v == cosize_v); - using SrcType = typename EngineIn::value_type; - using DstType = typename EngineOut::value_type; - - auto src = tArA(_, _, _, k_block); - auto dst = tArACompute(_, _, _, k_block); - constexpr int num_elements = decltype(size(src))::value; - - constexpr int pack = decltype(select_packing::value())::value; - using Converter = cutlass::NumericArrayConverter; - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - constexpr int DstElementsPerReg = 32 / sizeof_bits_v; - using RegArray = cutlass::AlignedArray; - - auto src_arr = recast(src); - auto dst_arr = recast(dst); - - Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack)); - - cute::transform(src_arr, dst_arr, Converter::convert); - - if constexpr (ModeHasScales) { - - auto const& scales = cute::get<1>(partitioned_extra_info)(_,_,_,k_block); - - CUTE_STATIC_ASSERT_V(size(src) == size(scales)); - - if constexpr (is_same_v) { - - using ScaleArray = cutlass::Array; - auto scale_arr = recast(filter_zeros(scales)); - - if constexpr (is_same_v){ - Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, pack)); - - for (int i = 0; i < size<1>(dst_vm); ++i){ - auto&& r = cute::recast(dst_vm(_,i))(0); - auto&& scale_reg = cute::recast(scales_vm(_,i))(0); - CUTLASS_PRAGMA_UNROLL - for (size_t ii = 0; ii < RegArray::kElements; ++ii) { - __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); - bf16x2_val = __hmul2(bf16x2_val, - reinterpret_cast(scale_reg[ii])); - } - } - } - else{ - cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{}); - } - } - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - // Do Nothing - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - static_assert(is_same_v, "ElementScale and ElementZero must be the same."); - - auto const& zeros = cute::get<3>(partitioned_extra_info)(_,_,_,k_block); - CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); - - if constexpr (is_same_v) { - using ZeroArray = cutlass::Array; - auto zero_arr = recast(filter_zeros(zeros)); - - if constexpr (is_same_v) { - Tensor zeros_vm = cute::group_modes<1,-1>(cute::zipped_divide(zeros, pack)); - - for (int i = 0; i < size<1>(dst_vm); ++i){ - auto&& r = cute::recast(dst_vm(_,i))(0); - auto&& zero_reg = cute::recast(zeros_vm(_,i))(0); - CUTLASS_PRAGMA_UNROLL - for (size_t ii = 0; ii < RegArray::kElements; ++ii) { - __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); - bf16x2_val = __hadd2(bf16x2_val, - reinterpret_cast(zero_reg[ii])); - } - } - } - else{ - cute::transform(dst_arr, zero_arr, dst_arr, cute::plus{}); - } - } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); - } - } -} - - - /// Utilities for any additional inputs inside of the TMA load - template < - class Params, - class TensorStorage, - class... Ts - > - CUTLASS_DEVICE - static auto partition_extra_tma_inputs( - Params const& mainloop_params, - cute::tuple const& load_inputs, - TensorStorage& shared_tensors, - uint2 const& cluster_local_block_id, - int const m_coord, - int const l_coord) { - - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return cute::make_tuple(); - } - else if constexpr (ModeHasScales) { - Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) - Tensor gS_mkl = get<2>(load_inputs); - auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); - Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - - Tensor tSgS = block_tma_s.partition_S(gS); - Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - return cute::make_tuple(tSgS, tSsS); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) - Tensor gZ_mkl = get<3>(load_inputs); - auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); - Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - - Tensor tZgZ = block_tma_z.partition_S(gZ); - Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) - return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); - } - } - - /// Utilities for partitioning extra inputs for loading from smem in the mainloop. - template < - class ThreadMma, - class TensorStorage - > - CUTLASS_DEVICE - static auto partition_extra_mma_info( - ThreadMma const& mma_thread_slice, - TensorStorage& shared_tensors) { - - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // nothing to do - return cute::make_tuple(); - } - else if constexpr (UseScaleLookupTable) { - Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsS = mma_thread_slice.partition_A(sS); - Tensor tCrS_neg = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); - Tensor tCrS_pos = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); - - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - return cute::make_tuple(tCsS, tCrS_neg, tCrS_pos); - } - } - else if constexpr (ModeHasScales) { - Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsS = mma_thread_slice.partition_A(sS); - Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); - - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - return cute::make_tuple(tCsS, tCrS); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsZ = mma_thread_slice.partition_A(sZ); - Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout()); - return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); - } - } - - template < - class TiledMma, - class TiledCopy, - class TensorStorage - > - CUTLASS_DEVICE - static auto partition_extra_transform_info( - TiledMma const& tiled_mma, - TiledCopy const& smem_tiled_copy_S, - TensorStorage& shared_storage) { - - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // nothing to do - return cute::make_tuple(); - } - else if constexpr (ModeHasScales) { - ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); - auto smem_thr_copy_S = smem_tiled_copy_S.get_slice(threadIdx.x % 128); - - Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsS = cta_mma.partition_A(sS); - Tensor tSsS = smem_thr_copy_S.partition_S(tCsS); - Tensor tSrS = make_tensor(tSsS(_,_,_,_,0).shape()); - - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsZ = cta_mma.partition_A(sZ); - Tensor tZsZ = smem_thr_copy_S.partition_S(tCsZ); - Tensor tZrZ = make_tensor(tZsZ(_,_,_,_,0).shape()); - return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS, tZrZ, tZsZ); - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); - } - } - - /// Returns the tiled copy and copy views for the extra inputs. - template - CUTLASS_DEVICE - static auto retile_extra_mma_info( - TiledMma const& tiled_mma, - cute::tuple& partitioned_extra_info, - int const warp_group_thread_idx) { - - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // nothing to do - return cute::make_tuple(); - } - else if constexpr (ModeHasScales) { - auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); - auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); - Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) - - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) - return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); - } - } -}; - -} // cutlass::gemm::collective::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/sm103_kernel_type.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/sm103_kernel_type.hpp deleted file mode 100644 index 04120a41ae0f404ed22ed05d08f138526b8e9fc3..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/sm103_kernel_type.hpp +++ /dev/null @@ -1,45 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Kernel type definitions specific for SM103 BlockScaled MMA -*/ - -#pragma once - -namespace cutlass::sm103::detail { - -enum class KernelPrefetchType { - TmaPrefetch, // TMA Prefetch (is the default version) - Disable // Disable Prefetch -}; - -} // namespace cutlass::sm103::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/dependent_false.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/dependent_false.hpp deleted file mode 100644 index d2dd6a16a67c12beece2645bf4781b820e07e78e..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/dependent_false.hpp +++ /dev/null @@ -1,86 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::detail { - -/// @brief A bool constant that depends on one or more template parameters. -/// -/// For more detailed documentation and use cases, -/// please see `dependent_false` below. -template -inline constexpr bool dependent_bool_value = Value; - -/// @brief An always-false value that depends on one or more template parameters. -/// -/// This exists because `static_assert(false);` always fails, -/// even if it occurs in the `else` branch of an `if constexpr`. -/// The following example shows how to use `dependent_false` in that case. -/// -/// @code -/// template -/// void foo (T t) -/// { -/// if constexpr (std::is_integral_v) { -/// do_integer_stuff(t); -/// } -/// else if constexpr (std::is_floating_point_v) { -/// do_floating_point_stuff(t); -/// } -/// else { -/// static_assert(dependent_false, "T must be " -/// "an integral or floating-point type."); -/// } -/// } -/// @endcode -/// -/// This implements the C++ Standard Library proposal P1830R1. -/// -/// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1830r1.pdf -/// -/// That proposal is under review as of 2022/12/05. -/// The following link shows P1830's current review status. -/// -/// https://github.com/cplusplus/papers/issues/572 -/// -/// P2593R0 proposes an alternate solution to this problem, -/// that would change the C++ language itself. -/// -/// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html -/// -/// For headers in this library, however, we only consider library solutions -/// as work-arounds for future C++ features. -template -inline constexpr bool dependent_false = dependent_bool_value; - -} // end namespace cutlass::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/helper_macros.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/helper_macros.hpp deleted file mode 100644 index cf9b803b27b3b148e1441260471ecab99e82bfd3..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/helper_macros.hpp +++ /dev/null @@ -1,242 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Helper macros for the CUTLASS library -*/ - -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -#ifdef CUTLASS_NAMESPACE -#define concat_tok(a, b) a ## b -#define mkcutlassnamespace(pre, ns) concat_tok(pre, ns) -#define cutlass mkcutlassnamespace(cutlass_, CUTLASS_NAMESPACE) -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ -#define CUTLASS_DEVICE __forceinline__ __device__ -#elif defined(__CUDACC_RTC__) -#define CUTLASS_HOST_DEVICE __forceinline__ __device__ -#define CUTLASS_DEVICE __forceinline__ __device__ -#else -#define CUTLASS_HOST_DEVICE inline -#define CUTLASS_DEVICE inline -#endif - -#if ! defined(_MSC_VER) -#define CUTLASS_LAMBDA_FUNC_INLINE __attribute__((always_inline)) -#else -#define CUTLASS_LAMBDA_FUNC_INLINE [[msvc::forceinline]] -#endif - -#define CUTLASS_HOST __host__ -#define CUTLASS_GLOBAL __global__ static - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) -{ } - -#if defined(__GNUC__) - #define CUTLASS_UNUSED(expr) __CUTLASS_UNUSED(expr) -#else - #define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr) -#endif - -#ifdef _MSC_VER -// Provides support for alternative operators 'and', 'or', and 'not' -#include -#endif // _MSC_VER - -#if !defined(__CUDACC_RTC__) -#include -#endif - -#if defined(__CUDA_ARCH__) - #if defined(_MSC_VER) - #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __FUNCSIG__); asm volatile ("brkpt;\n"); } - #else - #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __PRETTY_FUNCTION__); asm volatile ("brkpt;\n"); } - #endif -#else - #if defined(_MSC_VER) - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) - #else - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) - #endif -#endif - -// CUTLASS_CMATH_NAMESPACE is the namespace where code can find -// functions like isnan and log. Such functions are in -// the std namespace in host code, but in the global namespace -// in device code. -// -// The intended use case for this macro is in "using" declarations -// for making argument-dependent lookup (ADL) work in generic code. -// For example, if T is cutlass::half_t, the following code will -// invoke cutlass::isnan(half_t). If T is float, it will invoke -// std::isnan on host and ::isnan on device. (CUTLASS's support -// for NVRTC prevents it from using things in the std namespace -// in device code.) Correct use of "using" declarations can help -// avoid unexpected implicit conversions, like from half_t to float. -// -// template -// bool foo(T x) { -// using CUTLASS_CMATH_NAMESPACE :: isnan; -// return isnan(x); -// } -// -// Without this macro, one would need to write the following. -// -// template -// bool foo(T x) { -// #if defined(__CUDA_ARCH__) -// using ::isnan; -// #else -// using std::isnan; -// #endif -// return isnan(x); -// } - -#if defined(__CUDA_ARCH__) -# define CUTLASS_CMATH_NAMESPACE -#else -# define CUTLASS_CMATH_NAMESPACE std -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - - -#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0 -#endif - - -// CUDA 10.1 introduces the mma instruction -#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) -#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0 -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CUTLASS_ASSERT(x) assert(x) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. -#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__) - #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) - #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") - #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1") - #else - #define CUTLASS_PRAGMA_UNROLL #pragma unroll - #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1 - #endif - - #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL - -#else - - #define CUTLASS_PRAGMA_UNROLL - #define CUTLASS_PRAGMA_NO_UNROLL - #define CUTLASS_GEMM_LOOP - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if !defined(__CUDACC_RTC__) -#define CUTLASS_THREAD_LOCAL thread_local -#else -#define CUTLASS_THREAD_LOCAL -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(_MSVC_LANG) -# define CUTLASS_CPLUSPLUS _MSVC_LANG -#else -# define CUTLASS_CPLUSPLUS __cplusplus -#endif - -// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2018/n4762.pdf -// Section 14.8 Predefined macro names -#if (201703L <= CUTLASS_CPLUSPLUS) -#define CUTLASS_CONSTEXPR_IF_CXX17 constexpr -#define CUTLASS_CXX17_OR_LATER 1 -#else -#define CUTLASS_CONSTEXPR_IF_CXX17 -#define CUTLASS_CXX17_OR_LATER 0 -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// __CUDA_ARCH_SPECIFIC__ is introduced in CUDA 12.9 -#if !defined(CUDA_ARCH_CONDITIONAL) - -#if defined(__CUDA_ARCH_SPECIFIC__) -#define CUDA_ARCH_CONDITIONAL(ARCH_XXYY) (__CUDA_ARCH_SPECIFIC__ == ARCH_XXYY) -#else -#define CUDA_ARCH_CONDITIONAL(ARCH_XXYY) (false) -#endif - -#endif - -// __CUDA_ARCH_FAMILY_SPECIFIC__ is introduced in CUDA 12.9 -#if !defined(CUDA_ARCH_FAMILY) - -#if defined(__CUDA_ARCH_FAMILY_SPECIFIC__) -#define CUDA_ARCH_FAMILY(ARCH_XXYY) (__CUDA_ARCH_FAMILY_SPECIFIC__ == ARCH_XXYY) -#else -#define CUDA_ARCH_FAMILY(ARCH_XXYY) (false) -#endif - -#endif - -#if !defined(CUDA_ARCH_CONDITIONAL_OR_FAMILY) -#define CUDA_ARCH_CONDITIONAL_OR_FAMILY(ARCH_XXYY) \ - (CUDA_ARCH_CONDITIONAL(ARCH_XXYY) || CUDA_ARCH_FAMILY(ARCH_XXYY)) -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -}; // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/layout.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/layout.hpp deleted file mode 100644 index e1c1bd6c5529ccb7af5c70e558be327c81396106..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/layout.hpp +++ /dev/null @@ -1,434 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cute/layout.hpp" -#include "cute/pointer_sparse.hpp" // cute::is_sparse -#include "cute/swizzle.hpp" // cute::Swizzle -#include "cute/swizzle_layout.hpp" // cute::get_swizzle_portion -#include "cute/util/type_traits.hpp" -#include "cute/arch/copy_sm90_tma.hpp" -#include "cute/arch/copy_sm100_tma.hpp" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/numeric_types.h" -#include "cutlass/detail/collective.hpp" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::detail { - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// For each cutlass::layout, provides its corresponding cute stride types, 64b by default - -template -struct TagToStrideA { - using type = L; -}; - -// Maps to modes [M, K, L] -template <> -struct TagToStrideA { - using type = cute::Stride, int64_t>; - using tag = layout::RowMajor; -}; - -// Maps to modes [M, K, L] -template <> -struct TagToStrideA { - using type = cute::Stride, int64_t, int64_t>; - using tag = layout::ColumnMajor; -}; - -template -struct TagToStrideB { - using type = L; -}; - -// Maps to modes [N, K, L] -template <> -struct TagToStrideB { - using type = cute::Stride, int64_t, int64_t>; - using tag = layout::RowMajor; -}; - -// Maps to modes [N, K, L] -template <> -struct TagToStrideB { - using type = cute::Stride, int64_t>; - using tag = layout::ColumnMajor; -}; - -// For each cutlass::layout *, provides its corresponding cute stride types, 64b by default -// Used by pointer array and grouped gemm -// Maps to modes [M, K, L] -template <> -struct TagToStrideA { - using UnderlyingType = cute::Stride, cute::Int<0>>; - using type = UnderlyingType*; - using tag = layout::RowMajor; -}; - -// Maps to modes [M, K, L] -template <> -struct TagToStrideA { - using UnderlyingType = cute::Stride, int64_t, cute::Int<0>>; - using type = UnderlyingType*; - using tag = layout::ColumnMajor; -}; - -// Maps to modes [N, K, L] -template <> -struct TagToStrideB { - using UnderlyingType = cute::Stride, int64_t, cute::Int<0>>; - using type = UnderlyingType*; - using tag = layout::RowMajor; -}; - -// Maps to modes [N, K, L] -template <> -struct TagToStrideB { - using UnderlyingType = cute::Stride, cute::Int<0>>; - using type = UnderlyingType*; - using tag = layout::ColumnMajor; -}; - -// Maps to modes [M, N, L] -template -struct TagToStrideC : TagToStrideA { }; - -// Conv: Maps to modes ((P,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, cute::Int<1>, cute::Int<0>>; -}; - -// Conv: Maps to modes ((P,Q,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, cute::Int<1>, cute::Int<0>>; -}; - -// Conv: Maps to modes ((P,Q,Z,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, cute::Int<1>, cute::Int<0>>; -}; - -// Conv: Maps to modes (K, (C,S), _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, int64_t>, cute::Int<0>>; -}; - -// Conv: Maps to modes (K, (C,S,R), _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, int64_t, int64_t>, cute::Int<0>>; -}; - -// Conv: Maps to modes (K, (C,S,R,T), _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, int64_t, int64_t, int64_t>, cute::Int<0>>; -}; - -// Conv: Maps to modes ((C,S), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, int64_t>, int64_t, cute::Int<0>>; -}; - -// Conv: Maps to modes ((C,S,R), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, int64_t, int64_t>, int64_t, cute::Int<0>>; -}; - -// Conv: Maps to modes ((C,S,R,T), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, int64_t, int64_t, int64_t>, int64_t, cute::Int<0>>; -}; - -// Convenience aliases -template -using TagToStrideA_t = typename TagToStrideA::type; - -template -using TagToStrideB_t = typename TagToStrideB::type; - -template -using TagToStrideC_t = typename TagToStrideC::type; - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// For 2.x compatibility APIs, provide stride->layout tag mappers - -template -constexpr bool -is_major(Stride = {}) { - // Account for stride types with and without batch mode and batch modes with static zero stride - return cute::is_constant<1, decltype(cute::front(cute::get(cute::remove_pointer_t{})))>::value; -} - -template -constexpr bool -is_major(cute::Layout = {}) { - return is_major(Stride{}); -} - -// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices -template -constexpr -auto -stride_to_layout_tag_A() { - using InternalStrideA = cute::remove_pointer_t; - if constexpr (cute::is_layout::value) { - return stride_to_layout_tag_A(); - } - else if constexpr (is_major<0, StrideA>()) { // M major - return layout::ColumnMajor{}; - } - // Specialize for sparse layout - else if constexpr (cute::get<0>(InternalStrideA{}) == cute::_2{} && - cute::rank(cute::get<1>(InternalStrideA{})) == 2 && - cute::is_same_v(InternalStrideA{}))>>) { - return layout::ColumnMajor{}; - } - else { // K major - return layout::RowMajor{}; - } - - CUTE_GCC_UNREACHABLE; -} - -template -constexpr -auto -stride_to_layout_tag_B() { - using InternalStrideB = cute::remove_pointer_t; - if constexpr (cute::is_layout::value) { - return stride_to_layout_tag_B(); - } - else if constexpr (is_major<0, StrideB>()) { // N major - return layout::RowMajor{}; - } - else { // K major - return layout::ColumnMajor{}; - } - - CUTE_GCC_UNREACHABLE; -} - -template -constexpr -auto -stride_to_layout_tag_C() { - using InternalStrideC = cute::remove_pointer_t; - if constexpr (cute::is_layout::value) { - return stride_to_layout_tag_C(); - } - else if constexpr (is_major<0, StrideC>()) { // M major - return layout::ColumnMajor{}; - } - else { // N major - return layout::RowMajor{}; - } - - CUTE_GCC_UNREACHABLE; -} - -// Utilities to map Stride back on to their corresponding layout tags -template -struct StrideToLayoutTagA { - using type = decltype(detail::stride_to_layout_tag_A()); -}; - -template -struct StrideToLayoutTagB { - using type = decltype(detail::stride_to_layout_tag_B()); -}; - -template -struct StrideToLayoutTagC { - using type = decltype(detail::stride_to_layout_tag_C()); -}; - -// Convenience aliases -template -using StrideToLayoutTagA_t = typename StrideToLayoutTagA::type; - -template -using StrideToLayoutTagB_t = typename StrideToLayoutTagB::type; - -template -using StrideToLayoutTagC_t = typename StrideToLayoutTagC::type; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Inspects a tiled copy and whether its copy engine is TMA or not -template -constexpr bool is_tma_copy_engine() { - if constexpr (cute::is_void_v) { - return false; - } - else { - if constexpr ( cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v - ) { - return true; - } - } - return false; -} - -template -struct RawDtype { using type = X; }; - -template -struct RawDtype> { using type = typename X::raw_type; }; - - -// Inspects a TiledCopy and returns its alignment in terms of element count -template -constexpr int -get_alignment_count_from_gmem_tiled_copy() { - - if constexpr (cute::is_void_v) { - return 1; - } - - // Account for ElementC = void kernels - else if constexpr (cute::is_void_v) { - return 0; - } - - else { - // For TMA tiled copies, we know the alignment has to be 128 bits - if constexpr (is_tma_copy_engine()) { - if constexpr ( cute::is_same_v::type, cutlass::detail::float_e2m1_unpacksmem_t> || - cute::is_same_v::type, cutlass::detail::float_e3m2_unpacksmem_t> || - cute::is_same_v::type, cutlass::detail::float_e2m3_unpacksmem_t> || - cute::is_same_v::type, cutlass::detail::type_erased_dynamic_float4_unpacksmem_t> || - cute::is_same_v::type, cutlass::detail::type_erased_dynamic_float6_unpacksmem_t> || - cutlass::gemm::collective::detail::is_sm10x_f8f6f4_element() && cute::is_same_v::type, uint8_t>) { - return 128; - } - - // For sparse MMA, alignment in logical elements is increased by sparsity factor - if constexpr (cute::is_sparse_v) { - return 128 / sizeof_bits::value * ElementMma::sparsity; - } - return 128 / sizeof_bits::value; - } - else { - // For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN - return GmemTiledCopy::NumValSrc; - } - } -} - -// Return alignment bit requirements for the GEMM inputs. -template < - class ElementType - , bool IsF8F6F4SubBytes=false -> -constexpr int -get_input_alignment_bits() { - if constexpr (IsF8F6F4SubBytes && sizeof_bits::value == 4) { - // 16U4 format: The inner tensor size dimension should be multiple of 64B. - return 64 * 8; - } - else if constexpr (IsF8F6F4SubBytes && sizeof_bits::value == 6) { - // 16U6 format : The inner tensor size dimension must be a multiple of 96B. - return 96 * 8; - } - // TMA 16B alignment requirement - return 128; -} - -// Return alignment bit requirements for the GEMM outputs. -template -constexpr int -get_output_alignment_bits() { - if constexpr (sizeof_bits::value == 6) { - // 16U6 format : The inner tensor size dimension must be a multiple of 96B. - return 96 * 8; - } - // TMA 16B alignment requirement - return 128; -} - -// Check if tensor layout satisfies a given major alignment -template -CUTLASS_HOST_DEVICE constexpr -bool -check_alignment(cute::Layout const& layout) { - // Condition: shape must divide by Alignment without rounding - bool shape_check = cute::size(layout.shape()) == Alignment * cute::size(cute::upcast(layout)); - // Condition: every dynamic stride must be a multiple of Alignment - bool stride_check = cute::all_of(cute::flatten(layout.stride()), [](auto s){ return cute::is_static::value || (s % Alignment == 0); }); - return shape_check && stride_check; -} - -// Check if tensor layout satisfies a given major alignment -template -CUTLASS_HOST_DEVICE constexpr -bool -check_alignment(Shape const& shape, Stride const& stride) { - return check_alignment(cute::make_layout(shape, stride)); -} - -template -CUTLASS_HOST_DEVICE constexpr -size_t -alignment_for_swizzle(cute::Swizzle) { - static_assert(B >= 0 and M >= 0); - return size_t(1) << size_t(B + M + cute::abs(S)); -} - -template -CUTLASS_HOST_DEVICE constexpr -size_t -alignment_for_swizzle(Layout layout) { - return alignment_for_swizzle(cute::get_swizzle_portion(layout)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp deleted file mode 100644 index 84de1c7d3c9359b94ababf184a2c2db724236b11..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp +++ /dev/null @@ -1,75 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Mainloop Fusion configs specific for scale factors -*/ - -#pragma once - -#include // cute::void_t - -namespace cutlass::detail { - -///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct ElementSFType { - using type = void; -}; - -template -struct ElementSFType> { - using type = typename CollectiveMainloop::ElementSF; -}; - -template -struct LayoutSFAType { - using type = void; -}; - -template -struct LayoutSFAType> { - using type = typename CollectiveMainloop::LayoutSFA; -}; - -template -struct LayoutSFBType { - using type = void; -}; - -template -struct LayoutSFBType> { - using type = typename CollectiveMainloop::LayoutSFB; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mma.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mma.hpp deleted file mode 100644 index b4cbd3864a7fbfc524229cb183c62564cead1e7f..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mma.hpp +++ /dev/null @@ -1,87 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/arch/mma.h" -#include "cute/layout.hpp" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::detail { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct IsSparseTensorOp : cute::false_type { }; - -// TiledMma for sparse must have ValTypeE -template -struct IsSparseTensorOp> - : cute::true_type { }; - - -template -struct IsBlockScaledTensorOp : cute::false_type { }; - -// TiledMma for blockScaled must have FrgTypeSFA -template -struct IsBlockScaledTensorOp> - : cute::true_type { }; - - -// The following metafunction is used to extract the OperatorClass from a cutlass 3.x kernel. -template -struct get_operator_class { - static constexpr bool is_sparse_op = IsSparseTensorOp::value; - static constexpr bool is_block_scaled_op = IsBlockScaledTensorOp::value; - // All tensorop operations have atom shape's M >= 8 - static constexpr bool is_tensor_op = cute::size<0>(typename TiledMma::AtomShape_MNK{}) >= 8; - using type = cute::conditional_t< - is_tensor_op, - cute::conditional_t< - is_sparse_op, - cutlass::arch::OpClassSparseTensorOp, - cute::conditional_t< - is_block_scaled_op, - cutlass::arch::OpClassBlockScaledTensorOp, - cutlass::arch::OpClassTensorOp - > - >, - cutlass::arch::OpClassSimt - >; -}; - -template -using get_operator_class_t = typename get_operator_class::type; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp deleted file mode 100644 index e4f20cb237cb9b6960275339ee803e00f8e40031..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp +++ /dev/null @@ -1,242 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Blocked Scale configs specific for SM100 BlockScaled MMA -*/ - -#pragma once - -#include "cutlass/layout/matrix.h" - -#include "cute/int_tuple.hpp" -#include "cute/atom/mma_traits_sm100.hpp" - -namespace cutlass::detail { - -///////////////////////////////////////////////////////////////////////////////////////////////// -using namespace cute; - -template -struct Sm1xxBlockScaledBasicChunk { - - using Blk_MN = _128; - using Blk_SF = _4; - - using SfKMajorAtom = Layout< Shape< Shape<_32,_4>, Shape, _4>>, - Stride, Stride< _0, _1>>>; - using SfMNMajorAtom = Layout< Shape< Shape, _4>, Shape<_32,_4>>, - Stride, Stride<_16,_4>>>; - using SfAtom = cute::conditional_t; -}; - -template -struct Sm1xxBlockScaledConfig { - // We are creating the SFA and SFB tensors' layouts in the collective since they always have the same layout. - // k-major order - static constexpr int SFVecSize = SFVecSize_; - using Sm1xxBlkScaledChunk = Sm1xxBlockScaledBasicChunk; - using Blk_MN = typename Sm1xxBlkScaledChunk::Blk_MN; - using Blk_SF = typename Sm1xxBlkScaledChunk::Blk_SF; - using SfAtom = typename Sm1xxBlkScaledChunk::SfAtom; - - using LayoutSF = decltype(blocked_product(SfAtom{}, make_layout( make_shape(int32_t(0), int32_t(0), int32_t(0)), - make_stride(int32_t(0), _1{}, int32_t(0))))); - - CUTE_HOST_DEVICE - static constexpr auto - deduce_layoutSFA() { - return LayoutSF{}; - } - - CUTE_HOST_DEVICE - static constexpr auto - deduce_layoutSFB() { - return LayoutSF{}; - } - - // The following function is provided for user fill dynamic problem size to the layout_SFA. - template < class ProblemShape, class LayoutSFA = LayoutSF> - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape_SFA(ProblemShape problem_shape, LayoutSFA layout_sfa = LayoutSFA{}) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - return tile_to_shape(SfAtom{}, make_shape(M,K,L), Step<_2,_1,_3>{}); - } - - // The following function is provided for user fill dynamic problem size to the layout_SFB. - template - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape_SFB(ProblemShape problem_shape, LayoutSFB layout_sfb = LayoutSFB{}) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - return tile_to_shape(SfAtom{}, make_shape(N,K,L), Step<_2,_1,_3>{}); - } - - template - CUTE_HOST_DEVICE - static constexpr auto - deduce_smem_layoutSFA(TiledMma tiled_mma, TileShape_MNK tileshape_mnk) { - - constexpr int MMA_NSF = TiledMma::K / SFVecSize; - // Basic storage block for new Scaling Factor Layouts - using mnBasicBlockShape = Shape<_32,_4>; - using mnBasicBlockStride = Stride<_16,_4>; - using kBasicBlockShape = Shape, Int>; - using kBasicBlockStride = Stride<_0, _1>; - - // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) - using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), - cute::size<2>(TileShape_MNK{})))); - // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) - using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), - cute::size<2>(TileShape_MNK{})))); - // A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). - // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 32bits corresponds to the TMEM word size - using Blk_MN = typename Sm1xxBlkScaledChunk::Blk_MN; - using Blk_SF = typename Sm1xxBlkScaledChunk::Blk_SF; - using Blk_Elems = decltype(Blk_MN{} * Blk_SF{}); - - using TL_VMNK = typename TiledMma::ThrLayoutVMNK; - constexpr TL_VMNK tl_vmnk{}; - constexpr int MMA_M = cute::size<0>(TileShape_MNK{}) / cute::size<0>(tl_vmnk); - using mma_SFA_shape = decltype( make_shape( prepend(Int{}/Blk_MN{}, mnBasicBlockShape{}), kBasicBlockShape{})); - using mma_SFA_stride = decltype(make_stride( prepend( Blk_Elems{}, mnBasicBlockStride{}), kBasicBlockStride{})); - using sSFA_shape = decltype( make_shape( mma_SFA_shape{}, _1{}, make_shape( Blk_SF{}/Int{}, Int(TileShape_MNK{}) / SFVecSize / Blk_SF{}>{}))); - using sSFA_stride = decltype(make_stride(mma_SFA_stride{}, _0{}, make_stride( Int{}, Int{}))); - using SmemLayoutAtomSFA = decltype(make_layout(sSFA_shape{}, sSFA_stride{})); - return SmemLayoutAtomSFA{}; - } - - template - CUTE_HOST_DEVICE - static constexpr auto - deduce_smem_layoutSFB(TiledMma tiled_mma, TileShape_MNK tileshape_mnk) { - - constexpr int MMA_NSF = TiledMma::K / SFVecSize; - // Basic storage block for new Scaling Factor Layouts - using mnBasicBlockShape = Shape<_32,_4>; - using mnBasicBlockStride = Stride<_16,_4>; - using kBasicBlockShape = Shape, Int>; - using kBasicBlockStride = Stride<_0, _1>; - - // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) - using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), - cute::size<2>(TileShape_MNK{})))); - // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) - using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), - cute::size<2>(TileShape_MNK{})))); - // A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). - // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 32bits corresponds to the TMEM word size - using Blk_MN = typename Sm1xxBlkScaledChunk::Blk_MN; - using Blk_SF = typename Sm1xxBlkScaledChunk::Blk_SF; - using Blk_Elems = decltype(Blk_MN{} * Blk_SF{}); - - using TL_VMNK = typename TiledMma::ThrLayoutVMNK; - constexpr TL_VMNK tl_vmnk{}; - constexpr int MMA_N = cute::size<1>(TileShape_MNK{}); - // If MMA_N is 192, we need to operate at MMA_N = 256 granularity for UTCCP to work for ScaleFactorB. - // Both TMA and UTCCP will transfer scale factor B as if we have 256 columns in B matrix. - constexpr int MMA_N_SFB = cutlass::ceil_div(MMA_N, Blk_MN{}) * Blk_MN{}; - using mma_SFB_shape = decltype(make_shape( prepend( Int{}/Blk_MN{}, mnBasicBlockShape{}), kBasicBlockShape{})); - using mma_SFB_stride = decltype(make_stride(prepend( Blk_Elems{}, mnBasicBlockStride{}), kBasicBlockStride{})); - using sSFB_shape = decltype( make_shape( mma_SFB_shape{}, _1{}, make_shape( Blk_SF{}/Int{}, Int(TileShape_MNK{}) / SFVecSize / Blk_SF{}>{}))); - using sSFB_stride = decltype(make_stride(mma_SFB_stride{}, _0{}, make_stride( Int{}, Int{}))); - using SmemLayoutAtomSFB = decltype(make_layout(sSFB_shape{}, sSFB_stride{})); - return SmemLayoutAtomSFB{}; - } -}; - - -template -struct Sm1xxBlockScaledOutputConfig { - // We are creating the SFD tensors' layouts in the collective. - // k-major order - static constexpr int SFVecSize = SFVecSize_; - using Sm1xxBlkScaledChunk = cutlass::detail::Sm1xxBlockScaledBasicChunk; - using Blk_MN = typename Sm1xxBlkScaledChunk::Blk_MN; - using Blk_SF = typename Sm1xxBlkScaledChunk::Blk_SF; - using SfAtom = typename Sm1xxBlkScaledChunk::SfAtom; - - using LayoutKMajorSF = decltype(blocked_product(SfAtom{}, make_layout(make_shape (int32_t(0), int32_t(0), int32_t(0)), - make_stride(int32_t(0), _1{}, int32_t(0))))); - - using LayoutMNMajorSF = decltype(blocked_product(SfAtom{}, make_layout(make_shape (int32_t(0), int32_t(0), int32_t(0)), - make_stride( _1{}, int32_t(0), int32_t(0))))); - - using LayoutSF = cute::conditional_t; - - CUTE_HOST_DEVICE - static constexpr auto - deduce_layoutSFD() { - return LayoutSF{}; - } - - // The following function is provided for user fill dynamic problem size to the layout_SFC. - template - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape_SFD(ProblemShape problem_shape, LayoutSFD layout_sfc = LayoutSFD{}) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - if constexpr (major == UMMA::Major::K) { - return tile_to_shape(SfAtom{}, make_shape(M,N,L), Step<_2,_1,_3>{}); - } - else { - return tile_to_shape(SfAtom{}, make_shape(M,N,L), Step<_1,_2,_3>{}); - } - } -}; - -//// Describe the Scalefactor Tensor without VectorSize -struct Sm1xxBlockScaledTensorConfig { - // k-major order - // The blockscaled tensor does not need to know vectorsize - using Blk_M = _128; - using Blk_N = _4; - using SfAtom = Layout< Shape< Shape<_32,_4>, Shape<_4>>, - Stride, Stride<_1>>>; - - template - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape(ProblemShape problem_shape) { - auto problem_shape_MNL = append<3>(problem_shape, 1); - auto [M, N, L] = problem_shape_MNL; - return tile_to_shape(SfAtom{}, make_shape(M,N,L), Step<_2,_1,_3>{}); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp deleted file mode 100644 index b6c92c4d1199995029a550f61dce6a9903d7333e..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp +++ /dev/null @@ -1,182 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Block Wise Scale configs specific for SM100 Blockwise/Groupwise MMA -*/ - -#pragma once - -#include "cutlass/layout/matrix.h" - -#include "cute/int_tuple.hpp" -#include "cute/atom/mma_traits_sm100.hpp" - -namespace cutlass::detail{ - -///////////////////////////////////////////////////////////////////////////////////////////////// -using namespace cute; - -template -struct Sm100MixedInputBlockwiseScaleConfig { - - using ShapeScale = Shape, int32_t>, Shape, int32_t>, int32_t>; - - using StrideScale = conditional_t,Stride<_0,int32_t>, int32_t>, - Stride,Stride<_0,_1>, int32_t>>; - - using LayoutScale = Layout; - - CUTE_HOST_DEVICE - static constexpr auto - deduce_layout_scale() { - return LayoutScale{}; - } - - template - CUTE_HOST_DEVICE - static constexpr auto - smem_atom_layout_scale(CtaShape_MN_K cta_shape_mn_k) { - static_assert(cute::is_static_v, "Expect static CTA shape"); - - int constexpr size_MN = cute::get<0>(CtaShape_MN_K{}); - int constexpr size_K = cute::get<1>(CtaShape_MN_K{}); - - int constexpr SmemSizeMN = (SFVecSizeMN < size_MN) - ? SFVecSizeMN - : size_MN; - - int constexpr SmemSizeK = (SFVecSizeK < size_K) - ? SFVecSizeK - : size_K; - - int constexpr div_MN = cute::ceil_div(size_MN, SmemSizeMN); - int constexpr div_K = cute::ceil_div(size_K, SmemSizeK); - - auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (majorSFA == UMMA::Major::MN) { - return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int{})); - } - else { - return make_stride(make_stride(_0{}, Int{}), make_stride(_0{}, _1{})); - } - }(); - - return make_layout( - make_shape(make_shape(Int{}, Int{}), - make_shape(Int{}, Int{})), - strides - ); - } - - - - // The following function is provided for user fill dynamic problem size to the layout_SFA. - template - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape_scale(ScaledInputDim scale_input_dims) { - const auto scale_input_dims_MNKL = append<3>(scale_input_dims, 1); - - auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - auto [MN, K, L] = scale_input_dims_MNKL; - if constexpr (majorSFA == UMMA::Major::MN) { - return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(MN, SFVecSizeMN))); - } - else { - return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{})); - } - }(); - - auto [MN, K, L] = scale_input_dims_MNKL; - auto mk_layout = make_layout( - make_shape(make_shape(Int{}, cute::ceil_div(MN, SFVecSizeMN)), - make_shape(Int{}, cute::ceil_div(K, SFVecSizeK))), - strides - ); - - return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); - } - -}; - -template -struct RuntimeMixedInputBlockwiseScaleConfig { - - using ShapeScale = Shape, Shape, int32_t>; - - using StrideScale = conditional_t,Stride<_0,int32_t>, int32_t>, - Stride,Stride<_0,_1>, int32_t>>; - - using LayoutScale = Layout; - - CUTE_HOST_DEVICE - static constexpr auto - deduce_layout_scale() { - return LayoutScale{}; - } - - // The following function is provided for user fill dynamic problem size to the layout_S. - template - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape_scale(ProblemShape problem_shape, SFVecShape sf_vec_shape) { - auto problem_shape_MNKL = append<3>(problem_shape, 1); - - auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - auto [MN, K, L] = problem_shape_MNKL; - auto [sfmn, sfk] = sf_vec_shape; - if constexpr (majorScale == UMMA::Major::MN) { - return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(MN, sfmn))); - } - else { - return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); - } - }(); - - auto [MN, K, L] = problem_shape_MNKL; - auto [sfmn, sfk] = sf_vec_shape; - auto mk_layout = make_layout( - make_shape(make_shape(sfmn, cute::ceil_div(MN, sfmn)), - make_shape(sfk, cute::ceil_div(K, sfk))), - strides - ); - - return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_tmem_helper.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_tmem_helper.hpp deleted file mode 100644 index f12bac12dc898f177d95438ca10a1b060f4402ac..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_tmem_helper.hpp +++ /dev/null @@ -1,76 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - - -/*! \file - \brief TMEM Accumulator Helpers for SM100 -*/ - -#pragma once - -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" - - -namespace cutlass::detail{ -constexpr uint32_t TmemColMask = 0x0000'FFFF; - -template -CUTE_HOST_DEVICE -static constexpr auto find_tmem_tensor_col_offset(TmemTensor tensor) { - using namespace cute; - return cosize(recast(tensor).layout()) & TmemColMask; -} - -template -CUTE_HOST_DEVICE -static constexpr auto make_sm100_accumulator(TiledMma tiled_mma, AccumulatorShape acc_shape, EpilogueTile epilogue_tile) { - using namespace cute; - static_assert(rank(acc_shape) == 3 || (rank(acc_shape) == 4 && IsOverlappingAccum == false), - "Expect a rank >= 3 accumulator shape compatible with an SM100 tiled mma, Overlapping accumulators is only available for non-complex kernels"); - if constexpr (IsOverlappingAccum) { - Tensor accumulators_tmp = TiledMma::make_fragment_C(append(acc_shape, Int<2>{})); - return make_tensor( - accumulators_tmp.data(), - shape(accumulators_tmp), - replace<3>( - stride(accumulators_tmp), - Int<(256 - size<1>(EpilogueTile{})) * stride<0, 1>(accumulators_tmp.layout())>{})); - } else { - return TiledMma::make_fragment_C(append( - acc_shape, - Int{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) - } -} -} // namespace cutlass::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm103_blockscaled_layout.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm103_blockscaled_layout.hpp deleted file mode 100644 index 300448d7cd0273ac4572b5484d5780d98576d4b5..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm103_blockscaled_layout.hpp +++ /dev/null @@ -1,107 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Blocked Scale configs specific for SM103 BlockScaled MMA -*/ - -#pragma once - -#include "cutlass/layout/matrix.h" - -#include "cute/int_tuple.hpp" -#include "cute/atom/mma_traits_sm100.hpp" - -namespace cutlass::detail{ - -///////////////////////////////////////////////////////////////////////////////////////////////// -using namespace cute; - -template -struct Sm103BlockScaledBasicChunk { - - using Blk_MN = _128; - using Blk_SF = _4; - - using SfKMajorAtom = Layout< Shape< Shape< _8, _4, _4>, Shape, _4>>, - Stride, Stride< _0, _1>>>; - using SfMNMajorAtom = Layout< Shape< Shape, _4>, Shape<_8, _4, _4>>, - Stride, Stride<_16,_128, _4>>>; - using SfAtom = cute::conditional_t; -}; - -template -struct Sm103BlockScaledConfig { - // We are creating the SFA and SFB tensors' layouts in the collective since they always have the same layout. - // k-major order - static constexpr int SFVecSize = SFVecSize_; - using Sm103BlkScaledChunk = Sm103BlockScaledBasicChunk; - using Blk_MN = typename Sm103BlkScaledChunk::Blk_MN; - using Blk_SF = typename Sm103BlkScaledChunk::Blk_SF; - using SfAtom = typename Sm103BlkScaledChunk::SfAtom; - - using LayoutSF = decltype(tile_to_shape(SfAtom{}, make_shape(int(0),int(0),int(0)),Step<_2,_1,_3>{})); - - CUTE_HOST_DEVICE - static constexpr auto - deduce_layoutSFA() { - return LayoutSF{}; - } - - CUTE_HOST_DEVICE - static constexpr auto - deduce_layoutSFB() { - return LayoutSF{}; - } - - // The following function is provided for user fill dynamic problem size to the layout_SFA. - template < class ProblemShape, class LayoutSFA = LayoutSF> - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape_SFA(ProblemShape problem_shape, LayoutSFA layout_sfa = LayoutSFA{}) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - return tile_to_shape(SfAtom{}, make_shape(M,K,L), Step<_2,_1,_3>{}); - } - - // The following function is provided for user fill dynamic problem size to the layout_SFB. - template - CUTE_HOST_DEVICE - static constexpr auto - tile_atom_to_shape_SFB(ProblemShape problem_shape, LayoutSFB layout_sfb = LayoutSFB{}) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - return tile_to_shape(SfAtom{}, make_shape(N,K,L), Step<_2,_1,_3>{}); - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::detail diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/device_kernel.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/device_kernel.h deleted file mode 100644 index 5b1d3e5b1feb5e38ec9a57e6ee784b3e0e9b5a27..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/device_kernel.h +++ /dev/null @@ -1,129 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for generic CUTLASS kernel. -*/ - -#pragma once - -#include // CUTLASS_HOST_DEVICE -#include // cutlass::arch::synclog_* -#include // uint64_t - -// __grid_constant__ was introduced in CUDA 11.7. -#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) && !CUTLASS_CLANG_CUDA -# define CUTLASS_GRID_CONSTANT_SUPPORTED -#endif - -// __grid_constant__ can be enabled only on SM70+ -#if defined(CUTLASS_GRID_CONSTANT_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) -# define CUTLASS_GRID_CONSTANT_ENABLED -#endif - -#if ! defined(CUTLASS_GRID_CONSTANT) -# if defined(CUTLASS_GRID_CONSTANT_ENABLED) -# define CUTLASS_GRID_CONSTANT __grid_constant__ -# else -# define CUTLASS_GRID_CONSTANT -# endif -#endif - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -template struct Type2Type { using type=T; }; -// using the simple type to replace the complex type to reduce this symbol size -template struct GetUnderlyingKernel : public Type2Type {}; -template class Wrapper > struct GetUnderlyingKernel> : public Wrapper {}; -template using GetUnderlyingKernel_t = typename GetUnderlyingKernel::type; - - -//////////////////////////////////////////////////////////////////////////////// - -/// Generic CUTLASS kernel template. -template -CUTLASS_GLOBAL -void Kernel(typename Operator::Params params) { - // Dynamic shared memory base pointer - extern __shared__ int SharedStorageBase[]; - // Declare pointer to dynamic shared memory. - typename Operator::SharedStorage *shared_storage = - reinterpret_cast(SharedStorageBase); - - Operator op; - - op(params, *shared_storage); - cutlass::arch::synclog_print(); -} - - -/// Generic CUTLASS kernel template. -template -CUTLASS_GLOBAL -void Kernel2(typename Operator::Params params) { - // Dynamic shared memory base pointer - extern __shared__ int SharedStorageBase[]; - // Declare pointer to dynamic shared memory. - typename Operator::SharedStorage *shared_storage = - reinterpret_cast(SharedStorageBase); - - Operator::invoke(params, *shared_storage); - cutlass::arch::synclog_print(); - -} - - -//////////////////////////////////////////////////////////////////////////////// -// -// 3.0 specific launch -// -//////////////////////////////////////////////////////////////////////////////// - -/// Generic CUTLASS kernel template. -template -CUTLASS_GLOBAL -#ifdef __CUDACC__ -// Enclosing this in __CUDACC__ suppresses MSVC warnings. -__launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) -#endif // __CUDACC__ -void device_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params) -{ - // Dynamic shared memory base pointer - extern __shared__ char smem[]; - Operator op; - op(params, smem); - cutlass::arch::synclog_print(); - -} - -//////////////////////////////////////////////////////////////////////////////// -} /// namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_builder.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_builder.hpp deleted file mode 100644 index 2bd817a5dd6e8cad0b4295ae1ff41d1f838eebf3..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_builder.hpp +++ /dev/null @@ -1,126 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include // cute::DefaultCopy -#include // cute::is_base_of_v - -#include "cutlass/detail/dependent_false.hpp" -#include "cutlass/epilogue/fusion/callbacks.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Used to specify epilogue subtile shape or dispatch to automatic computation of subtile shape -struct EpilogueTileAuto {}; - -// Used to let the builder pick the epilogue schedule automatically. -// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp -struct EpilogueScheduleAuto {}; - -template < - class ArchTag, - class OpClass, - class TileShape_MNK, - class ClusterShape_MNK, - class EpilogueTileType, - class ElementAccumulator, - class ElementCompute, - class ElementC, - class GmemLayoutTagC, - int AlignmentC, - class ElementD, - class GmemLayoutTagD, - int AlignmentD, - class EpilogueScheduleType, - class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination, - class Enable = void -> -struct CollectiveBuilder { - static_assert(cutlass::detail::dependent_false, - "Could not build a collective epilogue for given parameters."); -}; - -// helper sub-builder for epilogue fusion callbacks (for internal use by CollectiveBuilder only) -namespace detail { - -// callbacks builder with operation tag -template< - class DispatchPolicy, - class FusionOp, - class TileShape_MNK, - class EpilogueTile_MN, - class ElementAccumulator, - class AccLoadOp = cute::DefaultCopy, - class = void -> -struct CallbacksBuilder { - using Callbacks = fusion::FusionCallbacks; -}; - -// callbacks builder with callbacks passthrough -template < - class DispatchPolicy, - class FusionCallbacks, - class TileShape_MNK, - class EpilogueTile_MN, - class AccLoadOp, - class ElementAccumulator -> -struct CallbacksBuilder< - DispatchPolicy, - FusionCallbacks, - TileShape_MNK, - EpilogueTile_MN, - ElementAccumulator, - AccLoadOp, - cute::enable_if_t> -> { - using Callbacks = FusionCallbacks; -}; - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "builders/sm90_builder.inl" -#include "builders/sm100_builder.inl" -#include "builders/sm103_builder.inl" -#include "builders/sm120_builder.inl" - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp deleted file mode 100644 index 918017efa4c22da5ad673fbecb55d2c7cea4d68c..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ /dev/null @@ -1,75 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - class DispatchPolicy, - class... Args -> -class CollectiveEpilogue { - static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "detail.hpp" - -// -// Gemm -// -#include "default_epilogue.hpp" -#include "default_epilogue_array.hpp" -#include "epilogue_tensor_broadcast.hpp" -#include "sm70_epilogue_vectorized.hpp" -#include "sm70_epilogue_vectorized_array.hpp" -#include "sm90_epilogue_tma_warpspecialized.hpp" -#include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" -#include "sm90_epilogue_array_tma_warpspecialized.hpp" -#include "sm100_epilogue_nosmem.hpp" -#include "sm100_epilogue_array_nosmem.hpp" -#include "sm100_epilogue_tma_warpspecialized.hpp" -#include "sm100_epilogue_array_tma_warpspecialized.hpp" -// -// Conv -// -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp deleted file mode 100644 index ed34bc10719d2ad45d22d890e3275ce8046c5385..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp +++ /dev/null @@ -1,265 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/arch/memory.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/detail.hpp" - -#include "cute/tensor.hpp" -#include "cute/numeric/numeric_types.hpp" -#include "cutlass/cuda_host_adapter.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Applies an element wise operation to all elements within the fragment -/// and writes them out to destination storage. -template < - class ElementC_, - class StrideC_, - class StrideD_, - class ThreadEpilogueOp_, - class EpilogueSchedule_ -> -class DefaultEpilogue { -public: - // - // Type Aliases - // - using EpilogueSchedule = EpilogueSchedule_; - using DispatchPolicy = EpilogueSchedule_; - - // derived types of output thread level operator - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementScalar = ElementCompute; - using ElementC = ElementC_; - using StrideC = StrideC_; - using ElementD = typename ThreadEpilogueOp::ElementD; - using StrideD = StrideD_; - - using GmemElementC = cute::conditional_t, ElementD, ElementC>; // prevents void ref breakages - - using GmemTiledCopyC = void; - using GmemTiledCopyD = void; - - static const int kOutputAlignment = ThreadEpilogueOp::kCount; - using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - - static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - - struct SharedStorage { }; - - using TensorStorage = SharedStorage; - - // Host side epilogue arguments - struct Arguments { - typename ThreadEpilogueOp::Params thread{}; - ElementC const* ptr_C = nullptr; - StrideC dC{}; - ElementD* ptr_D = nullptr; - StrideD dD{}; - }; - - // Device side epilogue params - using Params = Arguments; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments( - [[maybe_unused]] ProblemShape const& _, - Arguments const& args, - [[maybe_unused]] void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - template - static bool - can_implement( - [[maybe_unused]] ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - return true; - } - - // Note: SharedStorage is unused for DefaultEpilogue - CUTLASS_HOST_DEVICE - DefaultEpilogue(Params const& params_, SharedStorage const& shared_storage = SharedStorage()) - : params(params_), epilogue_op(params_.thread) { } - - CUTLASS_DEVICE - bool - is_source_needed() { - return epilogue_op.is_source_needed(); - } - - template< - class ProblemShapeMNKL, - class BlockShapeMNK, - class BlockCoordMNKL, - class FrgEngine, class FrgLayout, - class TiledMma, - class ResidueMNK - > - CUTLASS_DEVICE void - operator()( - ProblemShapeMNKL problem_shape_mnkl, - BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, - cute::Tensor const& accumulators, - TiledMma tiled_mma, - [[maybe_unused]] ResidueMNK, - int thread_idx, - [[maybe_unused]] char*) - { - using namespace cute; - using X = Underscore; - - static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - auto stride_c = detail::get_epilogue_stride(params.dC); - auto stride_d = detail::get_epilogue_stride(params.dD); - - // Represent the full output tensor - Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) - Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) - Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - - // Partition source and destination tiles to match the accumulator partitioning - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) - Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) - - static_assert(is_static::value, "Accumulator layout must be static"); - CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), - "Source and destination must have the same number of elements."); - CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), - "Accumulator count must have the same destination element count."); - - // OOB predication for tile quantization "residue" - // Absolute coordinate tensors (dynamic) - auto shape_MN = make_shape(M,N); - Tensor mD_crd = make_identity_tensor(shape_MN); // (M,N) - Tensor cD_mn = local_tile(mD_crd, take<0,2>(blk_shape_MNK), make_coord(m_coord, n_coord)); // (BLK_M,BLK_N) - Tensor tCcD_mn = thr_mma.partition_C(cD_mn); // (VEC,THR_M,THR_N) - // Relative coordinate tensors (static) - Tensor cD = make_coord_tensor(cD_mn.layout()); // (BLK_M,BLK_N) - Tensor tCcD = make_coord_tensor(tCcD_mn.layout()); // (VEC,THR_M,THR_N) - // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate - auto residue_cD = shape_MN - cD_mn(_0{}); // (m,n) - auto residue_tCcD = shape_MN - tCcD_mn(_0{}); // (m,n) - - // Fully OOB tile - if (not elem_less(repeat_like(residue_cD, _0{}), residue_cD)) { - return; - } - - using FragCType = remove_cvref_t; - using FragDType = remove_cvref_t; - - // source is needed - if (epilogue_op.is_source_needed()) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accumulators); ++i) { - FragCType fragC; - bool pred = elem_less(tCcD(i), residue_tCcD); - arch::global_load(fragC, &tCgC(i), pred); - FragDType fragD = epilogue_op(accumulators(i), fragC); - arch::global_store(fragD, &tCgD(i), pred); - } - } - // source is not needed, avoid load - else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accumulators); ++i) { - bool pred = elem_less(tCcD(i), residue_tCcD); - FragDType fragD = epilogue_op(accumulators(i)); - arch::global_store(fragD, &tCgD(i), pred); - } - } - } - -private: - Params params; - ThreadEpilogueOp epilogue_op; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue_array.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue_array.hpp deleted file mode 100644 index 3cab46ddcfd86ecbb2d3f1de43856f91e1002bfd..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue_array.hpp +++ /dev/null @@ -1,287 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/detail.hpp" - -#include "cute/tensor.hpp" -#include "cute/numeric/numeric_types.hpp" -#include "cutlass/trace.h" - -#include "cutlass/cuda_host_adapter.hpp" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Applies an element wise operation to all elements within the fragment -// and writes them out to destination storage. -template < - class ElementC_, - class StrideC_, - class StrideD_, - class ThreadEpilogueOp_, - class EpilogueSchedule_ -> -class DefaultEpilogueArray { -public: - // - // Type Aliases - // - using EpilogueSchedule = EpilogueSchedule_; - using DispatchPolicy = EpilogueSchedule_; - - // derived types of output thread level operator - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementScalar = ElementCompute; - using ElementC = ElementC_; - using StrideC = StrideC_; - using InternalStrideC = cute::remove_pointer_t; - using ElementD = typename ThreadEpilogueOp::ElementD; - using StrideD = StrideD_; - using InternalStrideD = cute::remove_pointer_t; - - using GmemElementC = cute::conditional_t, ElementD, ElementC>; // prevents void ref breakages - - using GmemTiledCopyC = void; - using GmemTiledCopyD = void; - - static const int kOutputAlignment = ThreadEpilogueOp::kCount; - using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - - static_assert(cute::is_same_v || cute::is_same_v || cute::is_same_v, "Incompatible epilogue schedule."); - static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - - struct SharedStorage { }; - - using TensorMapStorage = SharedStorage; - - // Host side epilogue arguments - struct Arguments { - typename ThreadEpilogueOp::Params thread{}; - ElementC const** ptr_C = nullptr; - StrideC dC{}; - ElementD** ptr_D = nullptr; - StrideD dD{}; - }; - - // Device side epilogue params - using Params = Arguments; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments( - ProblemShape const&, - Arguments const& args, - [[maybe_unused]] void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - template - static bool - can_implement( - [[maybe_unused]] ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - return true; - } - - CUTLASS_HOST_DEVICE - DefaultEpilogueArray(Params const& params_) - : params(params_) { } - - CUTLASS_DEVICE - bool - is_source_needed() { - // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. - return true; - } - - template< - class ProblemShapeMNKL, - class BlockShapeMNK, - class BlockCoordMNKL, - class FrgEngine, class FrgLayout, - class TiledMma, - class ResidueMNK - > - CUTLASS_HOST_DEVICE void - operator()( - ProblemShapeMNKL problem_shape_mnkl, - BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, - cute::Tensor const& accumulators, - TiledMma tiled_mma, - [[maybe_unused]] ResidueMNK, - int thread_idx, - [[maybe_unused]] char*) - { - using namespace cute; - using X = Underscore; - - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - // Batches are managed by using appropriate pointers to C and D matrices - const int32_t mock_L = 1; - const int32_t mock_l_coord = 0; - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - - // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. - // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, - // we get the correct alpha/beta values for the current batch/group using group index. - ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); - - if (epilogue_op.is_source_needed() && params.dC == nullptr) { - // Beta value is non-zero while pointer to C is a nullptr - assert(0); - } - - auto [stride_c, stride_d] = [&, l = l_coord]() { - if constexpr (!cute::is_same_v) { - // If grouped gemm - if (epilogue_op.is_source_needed()) { - return make_tuple( - detail::get_epilogue_stride(params.dC[l]), - detail::get_epilogue_stride(params.dD[l]) - ); - } - else { - return make_tuple( - InternalStrideC{}, - detail::get_epilogue_stride(params.dD[l]) - ); - } - } - else { - return make_tuple( - detail::get_epilogue_stride(params.dC), - detail::get_epilogue_stride(params.dD) - ); - } - }(); - - // Represent the full output tensor - ElementC const* ptr_C_l = nullptr; - if (epilogue_op.is_source_needed()) { - ptr_C_l = params.ptr_C[l_coord]; - } - Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) - Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) - Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gC = gC_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) - - // Partition source and destination tiles to match the accumulator partitioning - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) - Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) - - static_assert(is_static::value, "Accumulator layout must be static"); - CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), - "Source and destination must have the same number of elements."); - CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), - "Accumulator count must have the same destination element count."); - - // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD_mn = local_tile(mD_crd, take<0,2>(blk_shape_MNK), make_coord(m_coord, n_coord)); // (BLK_M,BLK_N) - Tensor tCcD = thr_mma.partition_C(cD_mn); // (VEC,THR_M,THR_N) - - // source is needed - if (epilogue_op.is_source_needed()) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), make_shape(M,N))) { - tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); - } - } - } - // source is not needed, avoid load - else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), make_shape(M,N))) { - tCgD(i) = epilogue_op(accumulators(i)); - } - } - } - } - -private: - Params params; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/detail.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/detail.hpp deleted file mode 100644 index fb09f8b19475fdeeca844b20a158933726d2a895..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/detail.hpp +++ /dev/null @@ -1,887 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/epilogue/dispatch_policy.hpp" - -#include "cute/tensor.hpp" -#include "cute/numeric/numeric_types.hpp" -#include "cute/util/type_traits.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -namespace detail { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -constexpr bool -is_m_major() { - return cutlass::gemm::detail::is_major<0,Stride>(); -} - -template -constexpr bool -is_n_major() { - return cutlass::gemm::detail::is_major<1,Stride>(); -} - -template -constexpr bool -is_im2col() { - return cute::is_same_v> - || cute::is_same_v> - || cute::is_same_v>; -} - -template -struct sm90_is_ptr_array_tma : cute::false_type {}; - -template<> -struct sm90_is_ptr_array_tma : cute::true_type {}; - -template<> -struct sm90_is_ptr_array_tma : cute::true_type {}; - -template<> -struct sm90_is_ptr_array_tma : cute::true_type {}; - -template -static constexpr bool sm90_is_ptr_array_tma_v = sm90_is_ptr_array_tma::value; - -template -struct sm90_is_ptr_array_tma_cooperative : cute::false_type {}; - -template<> -struct sm90_is_ptr_array_tma_cooperative : cute::true_type {}; - -template -static constexpr bool sm90_is_ptr_array_tma_cooperative_v = sm90_is_ptr_array_tma_cooperative::value; - -template -struct sm90_is_ptr_array_tma_pingpong : cute::false_type {}; - -template<> -struct sm90_is_ptr_array_tma_pingpong : cute::true_type {}; - -template -static constexpr bool sm90_is_ptr_array_tma_pingpong_v = sm90_is_ptr_array_tma_pingpong::value; - -template -struct sm90_is_ptr_array_tma_dispatch_policy : cute::false_type {}; - -template< - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups -> -struct sm90_is_ptr_array_tma_dispatch_policy< - Sm90PtrArrayTmaWarpSpecialized> - : cute::true_type {}; - -template< - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups -> -struct sm90_is_ptr_array_tma_dispatch_policy< - Sm120PtrArrayTmaWarpSpecialized> - : cute::true_type {}; - -template -static constexpr bool sm90_is_ptr_array_tma_dispatch_policy_v = sm90_is_ptr_array_tma_dispatch_policy::value; - -using cutlass::atomic_maximum; - -template -static constexpr int elements_per_access_v = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value; - -template -static constexpr bool sm90_is_cooperative_v = - cute::is_base_of_v || - sm90_is_ptr_array_tma_cooperative_v; - -template -static constexpr bool sm90_is_warp_specialized_v = - (!sm90_is_ptr_array_tma_cooperative_v && sm90_is_ptr_array_tma_v) || - cute::is_base_of_v; - -template -static constexpr bool is_im2col_mode = - cute::is_same_v || - cute::is_same_v || - cute::is_same_v; - -template -struct EmptyStorage { - CUTLASS_HOST_DEVICE - T* data() { return nullptr; } -}; - -template -CUTLASS_HOST_DEVICE -auto get_epilogue_stride(Stride stride){ - if constexpr (cute::is_base_of_v|| - cute::is_base_of_v) { - return cute::make_stride(cute::get<1>(stride), cute::get<0>(stride), cute::get<2>(stride)); - } - else { - return stride; - } -} - -template -struct IsThreadEpilogueOpWithBias { - static constexpr bool value = false; - using type = typename ThreadEpilogueOp::ElementCompute; -}; - -template -struct IsThreadEpilogueOpWithBias > { - static constexpr bool value = true; - using type = typename ThreadEpilogueOp::ElementBias; -}; - -template -struct IsThreadEpilogueOpWithPerChannelScaling { - static constexpr bool value = false; -}; - -template -struct IsThreadEpilogueOpWithPerChannelScaling > { - static constexpr bool value = true; -}; - -template -struct IsThreadEpilogueOpWithResidualAdd { - static constexpr bool value = false; -}; - -template -struct IsThreadEpilogueOpWithResidualAdd > { - static constexpr bool value = ThreadEpilogueOp::IsResidualSupported; -}; - -template -struct IsThreadEpilogueOpWithActivation { - static constexpr bool value = false; - using type = void; -}; - -template -struct IsThreadEpilogueOpWithActivation > { - static constexpr bool value = true; - using type = typename ThreadEpilogueOp::ActivationFn; -}; - -template -struct IsThreadEpilogueOpWithPerChannelScaled { - static constexpr bool value = false; -}; - -template -struct IsThreadEpilogueOpWithPerChannelScaled > { - static constexpr bool value = ThreadEpilogueOp::IsPerRowScaleSupported || ThreadEpilogueOp::IsPerColScaleSupported; -}; - -template -struct IsThreadEpilogueOpWithElementwiseArguments : cute::false_type {}; - -template -struct IsThreadEpilogueOpWithElementwiseArguments< - ThreadEpilogueOp, - cute::void_t> : cute::true_type {}; - -// Check if ActivationFn has 'Arguments' type defined -template -struct sm100_act_has_arguments : cute::false_type {}; - -template -struct sm100_act_has_arguments > : cute::true_type {}; - -template -struct Sm100EpilogueOpNumAccumulatorMtxs { - static constexpr int value = 1; -}; - -template -struct Sm100EpilogueOpNumAccumulatorMtxs> { - static constexpr int value = EpilogueOp::NumAccumulatorMtxs; -}; - - -// Wrapper class to use operator-style epilogues in sm90 TMA warp-specialized kernels -template -class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { -public: - using GmemTiledCopyC = void; - using GmemTiledCopyD = void; - - using LoadPipeline = cutlass::PipelineTransactionAsync<0>; - using LoadPipelineState = cutlass::PipelineState<0>; - constexpr static uint32_t TmaTransactionBytes = 0; - constexpr static bool RequiresTransactionBytes = false; - - using StorePipeline = cutlass::PipelineTmaStore<0>; - using StorePipelineState = cutlass::PipelineState<0>; - - using TensorStorage = typename EpilogueOp::SharedStorage; - using TensorMapStorage = typename EpilogueOp::SharedStorage; - using PipelineStorage = typename LoadPipeline::SharedStorage; - - template - CUTLASS_HOST_DEVICE - static constexpr int - get_load_pipe_increment(CtaTileMNK) { - return 1; - } - - template - CUTLASS_HOST_DEVICE - static constexpr int - get_store_pipe_increment(CtaTileMNK) { - return 1; - } - - CUTLASS_DEVICE - static void prefetch_tma_descriptors([[maybe_unused]] typename EpilogueOp::Params const&) { - } - - // ctor inheritance - using EpilogueOp::EpilogueOp; - - CUTLASS_HOST_DEVICE - Sm90TmaWarpSpecializedAdapter( - typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorStorage& shared_tensors) - : EpilogueOp(params) { } - - CUTLASS_DEVICE - bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE auto - load_init( - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx) { - return cute::make_tuple(nullptr); - } - - template< - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class TiledMma - > - CUTLASS_DEVICE auto - load( - [[maybe_unused]] LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state, - [[maybe_unused]] ProblemShapeMNKL problem_shape_mnkl, - [[maybe_unused]] CtaTileMNK cta_tile_mnk, - [[maybe_unused]] CtaCoordMNKL cta_coord_mnkl, - [[maybe_unused]] TiledMma tiled_mma, - [[maybe_unused]] int thread_idx, - [[maybe_unused]] TensorStorage& shared_tensors, - [[maybe_unused]] int subtile_idx=-1) - { - return load_pipe_producer_state; - } - - template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class TiledMma, - class TensorMapC - > - CUTLASS_DEVICE auto - load( - [[maybe_unused]] LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state, - [[maybe_unused]] ProblemShapeMNKL problem_shape_mnkl, - [[maybe_unused]] TileShapeMNK tile_shape_MNK, - [[maybe_unused]] TileCoordMNKL tile_coord_mnkl, - [[maybe_unused]] TiledMma tiled_mma, - [[maybe_unused]] int thread_idx, - [[maybe_unused]] TensorStorage& shared_tensors, - [[maybe_unused]] TensorMapC const& load_tensormap, - [[maybe_unused]] int subtile_idx=-1, - [[maybe_unused]] bool wait = false) - { - return load_pipe_producer_state; - } - - CUTLASS_DEVICE auto - load_tail( - [[maybe_unused]] LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state) - { - return load_pipe_producer_state; - } - - CUTLASS_DEVICE auto - store_init( - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx, - [[maybe_unused]] int32_t warp_group_idx) { - return cute::make_tuple(nullptr); - } - - template< - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class AccEngine, class AccLayout, - class TiledMma - > - CUTLASS_DEVICE auto - store( - [[maybe_unused]] LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - [[maybe_unused]] StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - cute::Tensor accumulators, - TiledMma tiled_mma, - int thread_idx, - TensorStorage& shared_tensors, - int subtile_index = -1) - { - constexpr int BLK_M_RANK = cute::rank<0>(cta_tile_mnk); - auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return get<0,i>(problem_shape_mnkl) - get<0,i>(cta_tile_mnk) * get<0,i>(cta_coord_mnkl); - })); - - constexpr int BLK_N_RANK = cute::rank<1>(cta_tile_mnk); - auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return get<1,i>(problem_shape_mnkl) - get<1,i>(cta_tile_mnk) * get<1,i>(cta_coord_mnkl); - })); - - auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); - - (*this)( - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - accumulators, - tiled_mma, - residue_mnk, - thread_idx, - reinterpret_cast(&shared_tensors)); - - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); - } - - template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class AccEngine, class AccLayout, - class TiledMma, - class TensorMapD - > - CUTLASS_DEVICE auto - store( - [[maybe_unused]] LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - [[maybe_unused]] StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_MNK, - TileCoordMNKL tile_coord_mnkl, - cute::Tensor accumulators, - TiledMma tiled_mma, - int thread_idx, - TensorStorage& shared_tensors, - [[maybe_unused]] TensorMapD const& store_tensormap, - int subtile_index = -1) - { - constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK); - auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl); - })); - - constexpr int BLK_N_RANK = cute::rank<1>(tile_shape_MNK); - auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl); - })); - - auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); - - (*this)( - problem_shape_mnkl, - tile_shape_MNK, - tile_coord_mnkl, - accumulators, - tiled_mma, - residue_mnk, - thread_idx, - reinterpret_cast(&shared_tensors)); - - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); - } - - CUTLASS_DEVICE auto - store_tail( - [[maybe_unused]] LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - [[maybe_unused]] StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state) { - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); - } - - // Dummy methods to perform different parts of TMA/Tensormap modifications - - template - CUTLASS_DEVICE - void - tensormaps_perform_update( - [[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, - [[maybe_unused]] ProblemShapeMNKL problem_shape, - [[maybe_unused]] int32_t next_batch, - [[maybe_unused]] int32_t warp_group_idx) { } - - template - CUTLASS_DEVICE - void - tensormaps_cp_fence_release( - [[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, - [[maybe_unused]] int32_t warp_group_idx) { } - - template - CUTLASS_DEVICE - void - tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { } -}; - - -// Wrapper class to use operator-style epilogues in sm100 TMA warp-specialized kernels -template -class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { -public: - using LoadPipeline = cutlass::PipelineTransactionAsync<0>; // 0 stage to disable smem alloc - using LoadPipelineState = cutlass::PipelineState<0>; - - using StorePipeline = cutlass::PipelineTmaStore<1>; // tma store pipe has no smem alloc - using StorePipelineState = cutlass::PipelineState<1>; - - using TensorStorage = typename EpilogueOp::SharedStorage; - using TensorMapStorage = typename EpilogueOp::SharedStorage; - using PipelineStorage = typename LoadPipeline::SharedStorage; - - static constexpr int NumAccumulatorMtxs = Sm100EpilogueOpNumAccumulatorMtxs::value; - - template - CUTLASS_HOST_DEVICE - static constexpr int - get_load_pipe_increment(CtaTileMNK) { - return 1; - } - - template - CUTLASS_HOST_DEVICE - static constexpr int - get_store_pipe_increment(CtaTileMNK) { - return 1; - } - - CUTLASS_DEVICE - static void prefetch_tma_descriptors([[maybe_unused]] typename EpilogueOp::Params const&) { - } - - CUTLASS_DEVICE - bool - is_producer_load_needed() const { - return false; - } - - // ctor inheritance - using EpilogueOp::EpilogueOp; - - CUTLASS_DEVICE auto - load_init( - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] int32_t const sm_count, - [[maybe_unused]] int32_t const sm_idx) const { - return cute::make_tuple(nullptr); - } - - template< - bool ReuseTmem = false, - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class MmaTileMNK, - class TiledMma - > - CUTLASS_DEVICE auto - load( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - TensorStorage& shared_tensors, - bool reverse_epi_n = false) - { - // C load is performed in epilogue operator - return load_pipe_producer_state; - } - - // with Tensormap - template< - bool ReuseTmem = false, - class ProblemShapeMNKL, - class CtaTileShapeMNK, - class CtaTileCoordMNKL, - class MmaTileMNK, - class TiledMma, - class TensorMap - > - CUTLASS_DEVICE auto - load( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileShapeMNK tile_shape_mnk, - CtaTileCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - TensorStorage& shared_tensors, - [[maybe_unused]] cute::tuple const& load_tensormap_info, - bool reverse_epi_n = false) - { - // C load is performed in epilogue operator - return load_pipe_producer_state; - } - - CUTLASS_DEVICE void - load_tail( - [[maybe_unused]] LoadPipeline load_pipeline, - [[maybe_unused]] LoadPipelineState load_pipe_producer_state, - [[maybe_unused]] StorePipeline store_pipeline, - [[maybe_unused]] StorePipelineState store_pipe_producer_state) - { - } - - CUTLASS_DEVICE auto - store_init( - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] int32_t const sm_count, - [[maybe_unused]] int32_t const sm_idx) const { - return cute::make_tuple(nullptr); - } - - template< - bool ReuseTmem = false, - class AccumulatorPipeline, - class AccumulatorPipelineState, - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class MmaTileMNK, - class TiledMma, - class AccEngine, - class AccLayout - > - CUTLASS_DEVICE auto - store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - AccumulatorPipeline acc_pipeline, - AccumulatorPipelineState acc_pipe_consumer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor accumulators, - TensorStorage& shared_tensors - ) - { - // Wait for mma warp to fill tmem buffer with accumulator results - acc_pipeline.consumer_wait(acc_pipe_consumer_state); - - auto [acc_state_next] = (*this).template operator()( - acc_pipeline, - acc_pipe_consumer_state, - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - accumulators, - shared_tensors); - - // Let mma warp know tmem buffer is consumed and empty - ++load_pipe_consumer_state; - ++store_pipe_producer_state; - - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_state_next); - } - - // FastF32 API - template< - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class MmaTileMNK, - class TiledMma, - class AccEngine, - class AccLayout, - class TiledCopyT2R - > - CUTLASS_DEVICE auto - store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor& tTR_rAcc, - TensorStorage& shared_tensors, - TiledCopyT2R tiled_t2r) - { - (*this)( - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - tTR_rAcc, - shared_tensors, - tiled_t2r); - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); - } - - // FastF32 API with Tensor Map - template< - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class MmaTileMNK, - class TiledMma, - class AccEngine, - class AccLayout, - class TiledCopyT2R, - class TensorMap - > - CUTLASS_DEVICE auto - store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor& tTR_rAcc, - TensorStorage& shared_tensors, - TensorMap tensormap, - TiledCopyT2R tiled_t2r) { - (*this)( - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - tTR_rAcc, - shared_tensors, - tiled_t2r); - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); - } - - template< - bool ReuseTmem = false, - class AccumulatorPipeline, - class AccumulatorPipelineState, - class ProblemShapeMNKL, - class CtaTileMNK, - class TileCoordMNKL, - class MmaTileMNK, - class TiledMma, - class AccEngine, - class AccLayout, - class TensorMap - > - CUTLASS_DEVICE auto - store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - AccumulatorPipeline acc_pipeline, - AccumulatorPipelineState acc_pipe_consumer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - TileCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor accumulators, - TensorStorage& shared_tensors, - TensorMap tensormap - ) - { - // Wait for mma warp to fill tmem buffer with accumulator results - acc_pipeline.consumer_wait(acc_pipe_consumer_state); - - auto [acc_state_next] = (*this).template operator()( - acc_pipeline, - acc_pipe_consumer_state, - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - accumulators, - shared_tensors); - - // Let mma warp know tmem buffer is consumed and empty - ++load_pipe_consumer_state; - ++store_pipe_producer_state; - - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_state_next); - } - - template - CUTLASS_DEVICE void - store_tail( - [[maybe_unused]] LoadPipeline load_pipeline, - [[maybe_unused]] LoadPipelineState load_pipe_consumer_state, - [[maybe_unused]] StorePipeline store_pipeline, - [[maybe_unused]] StorePipelineState store_pipe_producer_state, - [[maybe_unused]] CtaTileMNK cta_tile_mnk) - { - } - - // Dummy methods to perform different parts of TMA/Tensormap modifications - - template - CUTLASS_DEVICE - void - tensormaps_perform_update( - [[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, - [[maybe_unused]] ProblemShape problem_shape, - [[maybe_unused]] int32_t next_batch) { } - - template - CUTLASS_DEVICE - void - tensormaps_cp_fence_release( - [[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] cute::TmaDescriptor const* tensormap) { } - - template - CUTLASS_DEVICE - void - tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { } -}; - - -// SFINAE helpers for detecting beta/beta_ptr/beta_ptr_array in EVT arguments. -template -struct has_beta { - static constexpr bool value = false; -}; - -template -struct has_beta> { - static constexpr bool value = true; -}; - -template -struct has_beta_ptr { - static constexpr bool value = false; -}; - -template -struct has_beta_ptr> { - static constexpr bool value = true; -}; - -template -struct has_beta_ptr_array { - static constexpr bool value = false; -}; - -template -struct has_beta_ptr_array> { - static constexpr bool value = true; -}; - -} // namespace detail -} // namespace collective -} // namespace epilogue -} // namespace cutlass diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp deleted file mode 100644 index d32dd6aeefe91b2663a9c9adeee3848e16f6c08f..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp +++ /dev/null @@ -1,271 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Functor for performing tensor-tensor broadacasts atop existing epilogues. - - Concretely, the opeartion performed is the following: - UnaryOp( - BinaryOp1( - BinaryOp0( - Activation((alpha * A @ B) + bias), - beta * C0 - ), - beta * C1 - ) - ) - - where: - - C0 and C1 have the same extents as the output - - BinaryOp0 and BinaryOp1 perform elementwise binary operations - - UnaryOp is an elementwise operation -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" - -#include "cute/tensor.hpp" -#include "cutlass/cuda_host_adapter.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Collective epilogue that applies elementwise tensor-tensor operations atop other epilogues -/// -template < - class StrideC_, - class StrideD_, - class ThreadEpilogueOp_, - class EpilogueSchedule_, - bool PerColumnBias_ = false -> -class EpilogueTensorBroadcast { -public: - // - // Type Aliases - // - using EpilogueSchedule = EpilogueSchedule_; - - // derived types of output thread level operator - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementScalar = ElementCompute; - using ElementBias = typename ThreadEpilogueOp::ElementBias; - using ElementC = typename ThreadEpilogueOp::ElementC; - using StrideC = StrideC_; - using ElementD = typename ThreadEpilogueOp::ElementD; - using StrideD = StrideD_; - using ActivationFunctor = typename ThreadEpilogueOp::ActivationFunctor; - - static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - - static constexpr int kOutputAlignment = ThreadEpilogueOp::kCount; - using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - - static constexpr bool IsBinaryOp0Enabled = ThreadEpilogueOp::IsBinaryOp0Enabled; - static constexpr bool IsBinaryOp1Enabled = ThreadEpilogueOp::IsBinaryOp1Enabled; - static constexpr bool IsUnaryOpEnabled = ThreadEpilogueOp::IsUnaryOpEnabled; - - static constexpr bool PerColumnBias = PerColumnBias_; - using BiasStride = typename cute::conditional_t, Stride<_1, _0, _0>>; - - struct SharedStorage { }; - - // Host side epilogue arguments - struct Arguments { - typename ThreadEpilogueOp::Params thread{}; - StrideC dC{}; - ElementD* ptr_D = nullptr; - StrideD dD{}; - ElementBias* ptr_Bias = nullptr; - ElementC* ptr_C0 = nullptr; - ElementC* ptr_C1 = nullptr; - }; - - // Device side epilogue params - using Params = Arguments; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments( - [[maybe_unused]] ProblemShape const& _, - Arguments const& args, - [[maybe_unused]] void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - template - static bool - can_implement( - [[maybe_unused]] ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - return true; - } - - CUTLASS_HOST_DEVICE - EpilogueTensorBroadcast(Params const& params_) - : params(params_), epilogue_op(params_.thread) { } - - CUTLASS_DEVICE - bool - is_source_needed() { - return epilogue_op.is_source0_needed() || epilogue_op.is_source1_needed(); - } - - template< - class ProblemShapeMNKL, - class BlockShapeMNK, - class BlockCoordMNKL, - class FrgEngine, class FrgLayout, - class TiledMma, - class ResidueMNK - > - CUTLASS_HOST_DEVICE void - operator()( - ProblemShapeMNKL problem_shape_mnkl, - BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, - cute::Tensor const& accumulators, - TiledMma tiled_mma, - ResidueMNK residue_mnk, - int thread_idx, - [[maybe_unused]] char* smem_buf) - { - using namespace cute; - using X = Underscore; - - static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4"); - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - auto stride_c = detail::get_epilogue_stride(params.dC); - auto stride_d = detail::get_epilogue_stride(params.dD); - auto stride_bias = detail::get_epilogue_stride(BiasStride{}); - - // Represent the full output tensor - Tensor mC0_mnl = make_tensor(make_gmem_ptr(params.ptr_C0), make_shape(M,N,L), stride_c); // (m,n,l) - Tensor mC1_mnl = make_tensor(make_gmem_ptr(params.ptr_C1), make_shape(M,N,L), stride_c); // (m,n,l) - Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) - Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), stride_bias); // (m,n,l) - - Tensor gC0_mnl = local_tile(mC0_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gC1_mnl = local_tile(mC1_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gBias_mnl = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - - // Slice to get the tile this thread block is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - Tensor gC0 = gC0_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gC1 = gC1_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - - // Partition source and destination tiles to match the accumulator partitioning - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) - Tensor tCgC0 = thr_mma.partition_C(gC0); // (VEC,THR_M,THR_N) - Tensor tCgC1 = thr_mma.partition_C(gC1); // (VEC,THR_M,THR_N) - Tensor tCgBias = thr_mma.partition_C(gBias); // (VEC,THR_M,THR_N) - - static_assert(is_static::value, - "Accumulator layout must be static"); - CUTE_STATIC_ASSERT_V(size(tCgC0) == size(tCgD), - "Source and destination must have the same number of elements."); - CUTE_STATIC_ASSERT_V(size(tCgC1) == size(tCgD), - "Source and destination must have the same number of elements."); - CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), - "Accumulator count must have the same destination element count."); - CUTE_STATIC_ASSERT_V(size(tCgBias) == size(accumulators), - "Accumulator count must have the same destination element count."); - - auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); - Tensor tCcD = thr_mma.partition_C(cD); - - bool bias_needed = params.ptr_Bias != nullptr; - bool c0_needed = (params.ptr_C0 != nullptr) && epilogue_op.is_source0_needed(); - bool c1_needed = (params.ptr_C1 != nullptr) && epilogue_op.is_source1_needed(); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - ElementBias bias = bias_needed ? tCgBias(i) : ElementBias(0); - ElementC c0 = c0_needed ? tCgC0(i) : ElementC(0); - ElementC c1 = c1_needed ? tCgC1(i) : ElementC(0); - - tCgD(i) = epilogue_op(accumulators(i), c0, c1, bias); - } - } - } - -private: - Params params; - ThreadEpilogueOp epilogue_op; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp deleted file mode 100644 index d3b2d0880e56fe65d7dda6efb982dee52f23b3e2..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp +++ /dev/null @@ -1,937 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by Ptr-Array and Grouped GEMM epilogues. -*/ - - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" - -#include "cute/tensor.hpp" -#include "cute/numeric/numeric_types.hpp" -#include "cutlass/cuda_host_adapter.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Applies an element wise operation to all elements within the fragment -/// and writes it out to destination storage. -template < - class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class ThreadEpilogueOp_, - class CopyOpT2R_, - class AlignmentC_, - class AlignmentD_ -> -class CollectiveEpilogue< - Sm100PtrArrayNoSmem, - EpilogueTile_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - ThreadEpilogueOp_, - CopyOpT2R_, - AlignmentC_, - AlignmentD_, - cute::enable_if_t::value> -> { -public: - // - // Type Aliases - // - using DispatchPolicy = Sm100PtrArrayNoSmem; - using EpilogueTile = EpilogueTile_; - // derived types of output thread level operator - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementScalar = ElementCompute; - using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; - using ElementC = typename ThreadEpilogueOp::ElementC; - using StrideC = StrideC_; - using InternalStrideC = cute::remove_pointer_t; - using ElementD = ElementD_; - using StrideD = StrideD_; - using InternalStrideD = cute::remove_pointer_t; - using CopyOpT2R = CopyOpT2R_; - using AlignmentC = AlignmentC_; - using AlignmentD = AlignmentD_; - - using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages - - using GmemTiledCopyC = void; - using GmemTiledCopyD = void; - - constexpr static int ThreadCount = 128; - constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; - constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; - using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - constexpr static uint32_t TmaTransactionBytes = 0; - - struct SharedStorage { - struct TensorStorage { }; - struct TensorMapStorage { }; - }; - using TensorStorage = typename SharedStorage::TensorStorage; - using TensorMapStorage = typename SharedStorage::TensorMapStorage; - - // Host side epilogue arguments - struct Arguments { - typename ThreadEpilogueOp::Params thread{}; - ElementC const** ptr_C = nullptr; - StrideC dC{}; - ElementD** ptr_D = nullptr; - StrideD dD{}; - }; - - // Device side epilogue params - using Params = Arguments; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments( - [[maybe_unused]] ProblemShape const& problem_shape, - Arguments const& args, - [[maybe_unused]] void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int /*sm_count*/ = 0) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - template - static bool - can_implement( - [[maybe_unused]] ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - return true; - } - - CUTLASS_HOST_DEVICE - CollectiveEpilogue(Params const& params, SharedStorage&) : params(params) { }; - - template< - bool ReuseTmem = false, - class AccumulatorPipeline, - class AccumulatorPipelineState, - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class AccEngine, class AccLayout - > - CUTLASS_DEVICE auto - operator()( - AccumulatorPipeline acc_pipeline, - AccumulatorPipelineState acc_pipe_consumer_state, - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK cta_tile_shape_mnk, - TileCoordMNKL cta_coord_mnkl, - cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) - [[maybe_unused]] SharedStorage&) { - - using namespace cute; - using X = Underscore; - - static_assert(is_tmem::value, "Accumulator must be TMEM resident."); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; - - // Batches are managed by using appropriate pointers to C and D matrices - auto problem_shape_mnl = append<3>(make_shape(M,N),Int<1>{}); - auto cta_coord_mnl = append<3>(make_shape(m_coord, n_coord),Int<0>{}); - auto cta_tiler = take<0,2>(cta_tile_shape_mnk); - - // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. - // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, - // we get the correct alpha/beta values for the current batch/group using group index. - ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); - - auto [stride_c, stride_d] = [&, l = l_coord]() { - if constexpr (!cute::is_same_v) { - // If grouped gemm - if (epilogue_op.is_source_needed()) { - return make_tuple( - detail::get_epilogue_stride(params.dC[l]), - detail::get_epilogue_stride(params.dD[l]) - ); - } - else { - return make_tuple( - InternalStrideC{}, - detail::get_epilogue_stride(params.dD[l]) - ); - } - } - else { - return make_tuple( - detail::get_epilogue_stride(params.dC), - detail::get_epilogue_stride(params.dD) - ); - } - }(); - - // Get the residual tensor for the current batch - ElementC const* ptr_C_l = nullptr; - if (epilogue_op.is_source_needed()) { - ptr_C_l = params.ptr_C[l_coord]; - } - - // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L) - Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L) - Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - - // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) - auto tiled_t2r = make_tmem_copy(CopyOpT2R{}, tensor<0>(accumulators)); - auto thread_idx = threadIdx.x % size(tiled_t2r); - - auto thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) - Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) - Tensor tTR_rAcc = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) - - Tensor tTR_rC = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) - - Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) - Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) - Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) - - constexpr auto mclD = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gD.layout())){}; - constexpr int VD = cute::min(AlignmentD{}, size(mclD)); - Tensor tTR_rD_frag = make_tensor(shape(tTR_rAcc)); - Tensor tTR_rD_src = recast>(coalesce(tTR_rD_frag)); - Tensor tR2G_rD_dst = recast>(coalesce(tTR_gD)); - - Tensor tTR_cD_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclD.compose(Int{}))); - Tensor tDpD = make_tensor(shape(tR2G_rD_dst)); - - CUTLASS_PRAGMA_UNROLL - for (int t = 0; t < size(tDpD); t++) { - tDpD(t) = elem_less(tTR_cD_mn_frg(t), problem_shape_mnl); - } - - constexpr auto mclC = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gC.layout())){}; - constexpr int VC = cute::min(AlignmentC{}, size(mclC)); - - Tensor tTR_cC_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclC.compose(Int{}))); - Tensor tG2R_rC_dst = recast>(coalesce(tTR_gC)); - Tensor tCpC = make_tensor(shape(tG2R_rC_dst)); - - CUTLASS_PRAGMA_UNROLL - for (int t = 0; t < size(tCpC); t++) { - tCpC(t) = elem_less(tTR_cC_mn_frg(t), problem_shape_mnl); - } - Tensor tTR_rC_src = recast>(coalesce(tTR_gC)); - Tensor tTR_rC_dst = recast>(coalesce(tTR_rC)); - - // Detect interleaved complex fp32 kernels - [[maybe_unused]] Tensor accs = accumulators; - using ElementTmem = typename decltype(accs)::value_type; - constexpr bool is_interleaved_complex_f32 = is_complex::value && cute::is_same_v; - - // 1. Load accumulators into register from tmem - // Tmem -> rmem and transformation for interleaved complex kernels - if constexpr (is_interleaved_complex_f32) { - using ElementComputeAccumulator = float; - - Tensor tAccReal = accumulators(make_coord(_,_),_0{},_0{},_0{}); // (CTA_M,CTA_N) - Tensor tAccImag = accumulators(make_coord(_,_),_0{},_0{},_1{}); // (CTA_M,CTA_N) - Tensor tTR_tAccReal = thread_t2r.partition_S(tAccReal); // (T2R,T2R_M,T2R_N) - Tensor tTR_tAccImag = thread_t2r.partition_S(tAccImag); // (T2R,T2R_M,T2R_N) - Tensor tTR_rAccReal = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) - Tensor tTR_rAccImag = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) - - copy(tiled_t2r, tTR_tAccReal, tTR_rAccReal); - copy(tiled_t2r, tTR_tAccImag, tTR_rAccImag); - - // 1.1. Transform accumulators in registers - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rAccReal); i++) { - tTR_rAcc(i) = {tTR_rAccReal(i), tTR_rAccImag(i)}; - } - } - - // Standard tmem -> rmem epilogue - else { - Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) - Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); // (T2R,T2R_M,T2R_N) - - copy(tiled_t2r, tTR_tAcc, tTR_rAcc); - } - - cutlass::arch::fence_view_async_tmem_load(); - acc_pipeline.consumer_release(acc_pipe_consumer_state); - ++acc_pipe_consumer_state; - - // 2. Apply element-wise operation and store to gmem - // source is needed - if (epilogue_op.is_source_needed()) { - copy_if(tCpC, tTR_rC_src, tTR_rC_dst); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rAcc); i++) { - tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i), tTR_rC(i)); - } - - copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); - } - // source is not needed, avoid load - else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rAcc); i++) { - tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i)); - } - - copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); - } - - return cute::make_tuple(acc_pipe_consumer_state); - } - - // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. - // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator - // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the - // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. - template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class AccEngine, class AccLayout, - class TiledCopy - > - CUTLASS_DEVICE void - operator()( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK cta_tile_shape_mnk, - TileCoordMNKL cta_coord_mnkl, - cute::Tensor& tTR_rGlobAcc, // (MMA,MMA_M,MMA_N) - [[maybe_unused]] SharedStorage&, - TiledCopy tiled_t2r) { - - using namespace cute; - using X = Underscore; - - static_assert(is_rmem::value, "Accumulator must be Register resident."); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); - static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; - - // Batches are managed by using appropriate pointers to C and D matrices - auto problem_shape_mnl = append<3>(make_shape(M,N),Int<1>{}); - auto cta_coord_mnl = append<3>(make_shape(m_coord, n_coord),Int<0>{}); - auto cta_tiler = take<0,2>(cta_tile_shape_mnk); - - ThreadEpilogueOp epilogue_op{params.thread}; - // Get the residual tensor for the current batch - ElementC const* ptr_C_l = nullptr; - if (epilogue_op.is_source_needed()) { - ptr_C_l = params.ptr_C[l_coord]; - } - - // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) - Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) - Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - - - // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) - auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); - Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) - Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) - - - Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) - Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) - Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l) - - // 2. Apply element-wise operation and store to gmem - // source is needed - if (epilogue_op.is_source_needed()) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rGlobAcc); ++i) { - if (elem_less(tTR_cD(i), problem_shape_mnl)) { - tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i), tTR_gC(i)); - } - } - } - // source is not needed, avoid load - else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rGlobAcc); ++i) { - if (elem_less(tTR_cD(i), problem_shape_mnl)) { - tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i)); - } - } - } - } - -protected: - Params const& params; -}; - -template < - class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class FusionCallbacks_, - class CopyOpT2R_, - class AlignmentC_, - class AlignmentD_ -> -class CollectiveEpilogue< - Sm100PtrArrayNoSmem, - EpilogueTile_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - FusionCallbacks_, - CopyOpT2R_, - AlignmentC_, - AlignmentD_, - cute::enable_if_t::value> -> { -public: - // - // Type Aliases - // - // Required by the gemm::kernel - using DispatchPolicy = Sm100PtrArrayNoSmem; - using ElementC = ElementC_; - using ElementD = ElementD_; - using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages - using StrideC = StrideC_; - using StrideD = StrideD_; - using InternalStrideC = cute::remove_pointer_t; - using InternalStrideD = cute::remove_pointer_t; - using EpilogueTile = EpilogueTile_; - using CopyOpT2R = CopyOpT2R_; - using FusionCallbacks = FusionCallbacks_; - using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; - - using GmemTiledCopyC = void; - using GmemTiledCopyD = void; - -private: - constexpr static bool IsReductionBufferNeeded = ThreadEpilogueOp::IsDePerRowBiasSupported - || is_same_v; // alloc reduction buffer for custom EVTs - constexpr static size_t ImplicitSharedStorageSize = IsReductionBufferNeeded ? size(EpilogueTile{}) : 0; - - // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. - constexpr static bool UnrollEpiLoop = - not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; - -public: - constexpr static int ThreadCount = 128; - constexpr static uint32_t TmaTransactionBytes = 0; - - struct SharedStorage { - using FusionStorage = typename FusionCallbacks::SharedStorage; - FusionStorage thread; - array_aligned buffer; - }; - - // Host side epilogue arguments - struct Arguments { - typename FusionCallbacks::Arguments thread{}; - ElementC const** ptr_C = nullptr; - StrideC dC = {}; - ElementD** ptr_D = nullptr; - StrideD dD = {}; - }; - - // Device side epilogue params - struct Params { - typename FusionCallbacks::Params thread{}; - ElementC const** ptr_C = nullptr; - StrideC dC = {}; - ElementD** ptr_D = nullptr; - StrideD dD = {}; - }; - - // - // Constructor and Data Members - // - CUTLASS_DEVICE - CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors) - : fusion_callbacks(params_.thread, shared_tensors.thread) - , smem_buffer_ptr(shared_tensors.buffer.data()) - , params(params_) {}; - -protected: - FusionCallbacks fusion_callbacks; - uint8_t* smem_buffer_ptr; - Params const& params; - -public: - - template - static constexpr Params - to_underlying_arguments( - [[maybe_unused]] ProblemShape const& problem_shape, - Arguments const& args, - [[maybe_unused]] void* workspace) { - return { - FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), - args.ptr_C, - args.dC, - args.ptr_D, - args.dD - }; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int /*sm_count*/ = 0) { - return FusionCallbacks::get_workspace_size(problem_shape, args.thread); - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); - } - - template - static bool - can_implement( - [[maybe_unused]] ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - - bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); - if (!fusion_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); - } - return fusion_implementable; - } - - - template< - bool ReuseTmem = false, - class AccumulatorPipeline, - class AccumulatorPipelineState, - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class AccEngine, class AccLayout - > - CUTLASS_DEVICE auto - operator()( - AccumulatorPipeline acc_pipeline, - AccumulatorPipelineState acc_pipe_consumer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - cute::Tensor accumulators, - [[maybe_unused]] SharedStorage& - ) { - using ElementAccumulator = typename AccEngine::value_type; - using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; - using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; - - // Wait for mma warp to fill tmem buffer with accumulator results - static_assert(is_tmem::value, "Accumulator must be TMEM resident."); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(rank(CtaCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); - static_assert(cute::sizeof_bits_v != 6, "Output element requires smem"); - - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; - - // Batches are managed by using appropriate pointers to C and D matrices - auto problem_shape_mnl = append<3>(make_shape(M,N),Int<1>{}); - auto cta_coord_mnl = append<3>(make_shape(m_coord, n_coord),Int<0>{}); - auto cta_tiler = take<0,2>(cta_tile_mnk); - auto cta_coord_mnk = cute::make_coord(m_coord, n_coord, k_coord, cute::Int<0>{}); - - bool is_C_load_needed = fusion_callbacks.is_C_load_needed(); - - auto [stride_c, stride_d] = [&, l = l_coord]() { - if constexpr (!cute::is_same_v) { - // If grouped gemm - if (is_C_load_needed) { - return make_tuple( - detail::get_epilogue_stride(params.dC[l]), - detail::get_epilogue_stride(params.dD[l]) - ); - } - else { - return make_tuple( - InternalStrideC{}, - detail::get_epilogue_stride(params.dD[l]) - ); - } - } - else { - return make_tuple( - detail::get_epilogue_stride(params.dC), - detail::get_epilogue_stride(params.dD) - ); - } - }(); - - // Get the residual tensor for the current batch - ElementC const* ptr_C_l = nullptr; - if (is_C_load_needed) { - ptr_C_l = params.ptr_C[l_coord]; - } - - int thread_idx = threadIdx.x % ThreadCount; - - Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) - Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); - ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - - constexpr int FragmentSize = size(EpilogueTile{}) / ThreadCount; - - Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) - Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) - Tensor cD_epi = flat_divide(cD, EpilogueTile{}); - Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l) - - Tensor tTR_rAcc = make_tensor(shape(tTR_cD(_,_,_,_0{},_0{}))); - - // Construct the EVT consumer callbacks - auto residue_cD = make_coord(M,N) - cD(_0{}); - auto residue_tTR_cD = make_coord(M,N) - tTR_cD(_0{}); - Tensor cD_ = make_coord_tensor(cD.layout()); - Tensor tTR_cD_ = make_coord_tensor(tTR_cD.layout()); - constexpr bool RefSrc = false; - - Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); - - Tensor tTR_gC = cutlass::epilogue::fusion::sm90_partition_for_epilogue( - mC, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); - - Tensor mD = make_tensor(make_gmem_ptr(recast_ptr(params.ptr_D[l_coord])), problem_shape_mnl, stride_d); - - Tensor tTR_gD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( - mD, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); - - // Register Tensor - Tensor tTR_rD = make_tensor(take<0,3>(shape(tTR_gD))); - - Tensor coord_cCD = make_identity_tensor(problem_shape_mnl); - Tensor tTR_cCD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( - coord_cCD, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); - constexpr auto mclD = decltype(max_common_layout(tTR_gD(_,_,_,_0{},_0{}), tTR_rD)){}; - constexpr int VD = cute::min(AlignmentD_{}, size(mclD)); - - auto tCrC = make_tensor(take<0,3>(shape(tTR_gC))); - constexpr auto mclC = decltype(max_common_layout(tTR_gC(_,_,_,_0{},_0{}), tCrC)){}; - constexpr int VC = cute::min(AlignmentC_{}, size(mclC)); - - Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); - - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - int(0), - EpilogueTile{}, - tiled_t2r, - cD_, - residue_cD, - tTR_cD_, - residue_tTR_cD, - tCrC, - thread_idx - }; - - auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // The Epilogue Loop - auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - // Ensure there are no threads from the previous wave writing to shared memory being utilized for the current wave. - synchronize(); - cst_callbacks.begin(); - if (cst_callbacks.begin_sync_needed()) { - synchronize(); - } - - // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) - // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of - // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. - // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. - // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. - // Then, the next accumulation stage for buffer 1 can start. - [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; - static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); - - // For each epilogue subtile within the CTA tile - constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<4>(tTR_tAcc)); - constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<3>(tTR_tAcc)); - - // Lambda to process a single epilogue tile - auto process_tile = [&](int epi_m, int epi_n, int iter_m, int iter_n) CUTLASS_LAMBDA_FUNC_INLINE { - bool is_last_iteration = iter_m == NumEpiSubtilesM-1 && iter_n == NumEpiSubtilesN-1; - bool do_acc_release = is_last_iteration; - - // Adjust release condition for tmem reuse - if constexpr (ReuseTmem) { - do_acc_release = iter_m == NumEpiSubtilesM-1 && iter_n == 0; // Release on first N iteration - } - - Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n); - Tensor tTR_pCD_mn = cute::lazy::transform(tTR_cCD_mn, [&] (auto const& c) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(c, problem_shape_mnl); }); - cst_callbacks.begin_loop(epi_m, epi_n); - - if constexpr (not cute::is_void_v) { - if (is_C_load_needed) { - using CVecType = uint_bit_t>; - - if constexpr (!is_same_v) { - Tensor tTR_gC_frg = recast(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); - Tensor tTR_rC_frg = recast(coalesce(tCrC)); - Tensor tTR_pC_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclC.compose(Int{}))); - copy_if(tTR_pC_frg, tTR_gC_frg, tTR_rC_frg); - } - else { - auto tiled_g2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - auto thr_g2r = tiled_g2r.get_slice(threadIdx.x); - Tensor c_src = thr_g2r.retile_S(tTR_gC(_,_,_,epi_m,epi_n)); - Tensor c_dst = thr_g2r.retile_D(tCrC); - Tensor c_prd = thr_g2r.retile_D(tTR_pCD_mn); - copy_if(tiled_g2r, c_prd, c_src, c_dst); - } - } - } - - // Copy accumulator tile from tmem to register - // The current tile in tmem - Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); - - Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); - - copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); - - // After the last tmem load, signal that tmem buffer is consumed and empty - if (do_acc_release) { - cutlass::arch::fence_view_async_tmem_load(); - acc_pipeline.consumer_release(acc_pipe_consumer_state); - ++acc_pipe_consumer_state; - } - - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tTR_rAcc_frg); ++epi_v) { - tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); - } - - Tensor reduction_buffer = make_tensor( - raw_pointer_cast(make_smem_ptr(smem_buffer_ptr)), make_layout(Shape>{})); - - cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rAcc /*not used*/); - - cst_callbacks.end_loop(epi_m, epi_n); - - using VecType = uint_bit_t>; - if constexpr (!is_same_v) { - Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); - Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); - Tensor tTR_pD_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclD.compose(Int{}))); - copy_if(tTR_pD_frg, tTR_rD_frg, tTR_gD_frg); - } - else { - auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - auto thr_r2g = tiled_r2g.get_slice(threadIdx.x); - Tensor src = thr_r2g.retile_S(tTR_rD); - Tensor dst = thr_r2g.retile_D(tTR_gD(_,_,_,epi_m,epi_n)); - Tensor prd = thr_r2g.retile_D(tTR_pCD_mn); - copy_if(tiled_r2g, prd, src, dst); - } - }; - - // Use static iteration with appropriate ordering - // When ReuseTmem is true and reverse_epi_n is true, we need reverse N iteration - auto n_seq = cute::make_int_sequence{}; - auto m_seq = cute::make_int_sequence{}; - - if constexpr (UnrollEpiLoop) { - // Fully unrolled static iteration - cute::for_each(n_seq, [&](auto I_N) CUTLASS_LAMBDA_FUNC_INLINE { - constexpr int iter_n = I_N; - int epi_n = iter_n; - if constexpr (ReuseTmem) { - if (reverse_epi_n) { - epi_n = NumEpiSubtilesN - 1 - iter_n; // Reverse N iteration - } - } - - cute::for_each(m_seq, [&](auto I_M) CUTLASS_LAMBDA_FUNC_INLINE { - constexpr int iter_m = I_M; - process_tile(iter_m, epi_n, iter_m, iter_n); - }); - }); - } else { - // Runtime loop with pragma unroll(1) - #pragma unroll 1 - for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { - int epi_n = iter_n; - if constexpr (ReuseTmem) { - if (reverse_epi_n) { - epi_n = NumEpiSubtilesN - 1 - iter_n; // Reverse N iteration - } - } - - #pragma unroll 1 - for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { - process_tile(iter_m, epi_n, iter_m, iter_n); - } - } - } - - cst_callbacks.end(); - }; - - // - // BEGIN EPILOGUE - // - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - epi_loop_fn(cst_callbacks); - return cute::make_tuple(acc_pipe_consumer_state); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// For sm100 kernels requiring warp specialized epilogues -template < - class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class ThreadEpilogueOp_, - class CopyOpT2R_, - class AlignmentC, - class AlignmentD -> -class CollectiveEpilogue< - Sm100PtrArrayNoSmemWarpSpecialized, - EpilogueTile_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - ThreadEpilogueOp_, - CopyOpT2R_, - AlignmentC, - AlignmentD -> : public detail::Sm100TmaWarpSpecializedAdapter> -{ -public: - // ctor inheritance - using detail::Sm100TmaWarpSpecializedAdapter>::Sm100TmaWarpSpecializedAdapter; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp deleted file mode 100644 index 1f0a915d7d61411de8f1fd6158365904258bf9fd..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp +++ /dev/null @@ -1,1526 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Functor performing elementwise operations used by Ptr-Array and Grouped Gemm epilogue. -*/ - - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/arch/barrier.h" -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/epilogue/thread/scale_type.h" -#include "cutlass/epilogue/fusion/callbacks.hpp" -#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" -#include "cutlass/detail/layout.hpp" -#include "cutlass/trace.h" - -#include "cute/tensor.hpp" -#include "cutlass/cuda_host_adapter.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_, - class CtaTileShape_, // (CTA_M,CTA_N,CTA_K) - class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class FusionCallbacks_, - class CopyOpT2R_, - class CopyOpG2S_, - class SmemLayoutAtomC_, - class CopyOpS2R_, - class CopyOpS2G_, - class SmemLayoutAtomD_, - class CopyOpR2S_, - class CopyOpR2R_ -> -class CollectiveEpilogue< - Sm100PtrArrayTmaWarpSpecialized, - CtaTileShape_, - EpilogueTile_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - FusionCallbacks_, - CopyOpT2R_, - CopyOpG2S_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpS2G_, - SmemLayoutAtomD_, - CopyOpR2S_, - CopyOpR2R_ -> { -public: - // - // Type Aliases - // - using DispatchPolicy = Sm100PtrArrayTmaWarpSpecialized; - using CtaTileShape = CtaTileShape_; - using EpilogueTile = EpilogueTile_; - using FusionCallbacks = FusionCallbacks_; - using ElementC = ElementC_; - using StrideC = StrideC_; - using InternalStrideC = cute::remove_pointer_t; - using ElementD = ElementD_; - using StrideD = StrideD_; - using InternalStrideD = cute::remove_pointer_t; - using CopyOpT2R = CopyOpT2R_; - using CopyOpG2S = CopyOpG2S_; - using SmemLayoutAtomC = SmemLayoutAtomC_; - using CopyOpS2R = CopyOpS2R_; - using CopyOpS2G = CopyOpS2G_; - using SmemLayoutAtomD = SmemLayoutAtomD_; - using CopyOpR2S = CopyOpR2S_; - using CopyOpR2R = CopyOpR2R_; - - using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; - using GmemTiledCopyC = CopyOpG2S; - using GmemTiledCopyD = CopyOpS2G; - - constexpr static int ThreadCount = 128; - - static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); - static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); - -private: - - constexpr static bool is_source_supported = not cute::is_void_v; - constexpr static bool is_destination_supported = not cute::is_void_v; - using GmemElementD = cute::conditional_t>; - using GmemElementC = cute::conditional_t; // prevents void ref breakages - static_assert(not cute::is_void_v, "GmemElementD is void"); - - using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; - using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - static_assert(StagesC >= 1, "StagesC must be >= 1"); - static_assert(StagesD >= 1, "StagesD must be >= 1"); - - constexpr static bool ReuseSmemC = ReuseSmemC_ && is_destination_supported; - - constexpr static bool is_m_major_C = detail::is_m_major(); - constexpr static bool is_m_major_D = detail::is_m_major(); - - using SmemLayoutStageC = decltype(tile_to_shape(SmemLayoutAtomC{}, product_each(shape(EpilogueTile{})), - cute::conditional_t, Step<_1,_2>>{} )); - using SmemLayoutStageD = decltype(tile_to_shape(SmemLayoutAtomD{}, product_each(shape(EpilogueTile{})), - cute::conditional_t, Step<_1,_2>>{} )); - - constexpr static int StageCBits = cosize_v * sizeof_bits_v; - constexpr static int StageDBits = cosize_v * sizeof_bits_v; - constexpr static int MaxStageBits = cute::max(StageCBits, StageDBits); - constexpr static int StrideStageC = (ReuseSmemC ? MaxStageBits : StageCBits) / sizeof_bits_v; - constexpr static int StrideStageD = (ReuseSmemC ? MaxStageBits : StageDBits) / sizeof_bits_v; - - using SmemLayoutC = decltype(cute::append<3>(SmemLayoutStageC{}, Layout, Int>{})); - using SmemLayoutD = decltype(cute::append<3>(SmemLayoutStageD{}, Layout, Int>{})); - - constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC - && MaxStageBits % sizeof_bits_v == 0 - && MaxStageBits % sizeof_bits_v == 0; - static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); - - constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); - constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); - constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); - - // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. - constexpr static bool UnrollEpiLoop = - not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; - // TMA store delay only benefits with loop unrolling - constexpr static bool DelayTmaStore = DelayTmaStore_ and UnrollEpiLoop; - - struct CollectiveStorageWithC { - alignas(SmemAlignmentC) ArrayEngine> smem_C; - alignas(SmemAlignmentD) ArrayEngine> smem_D; - }; - - union CollectiveStorageWithoutC { - cute::array smem_C; - alignas(SmemAlignmentD) ArrayEngine> smem_D; - }; - - union CollectiveStorageReuseC { - alignas(MaxSmemAlignment) ArrayEngine> smem_C; - alignas(MaxSmemAlignment) ArrayEngine> smem_D; - }; - -public: - // TMA pipeline for loading C - using LoadPipeline = cutlass::PipelineTransactionAsync; - using LoadPipelineState = cutlass::PipelineState; - constexpr static uint32_t TmaTransactionBytes = StageCBits / 8; - constexpr static uint32_t MinTensorMapWorkspaceAlignment = 64; - - // TMA pipeline for storing D - using StorePipeline = cute::conditional_t, - cutlass::PipelineTmaStore>; - using StorePipelineState = cutlass::PipelineState; - - struct SharedStorage { - struct TensorStorage { - using CollectiveStorage = cute::conditional_t>; - CollectiveStorage collective; - - using FusionStorage = typename FusionCallbacks::SharedStorage; - FusionStorage thread; - } tensors; - - struct TensorMapStorage : cute::aligned_struct<128, _0> { - cute::TmaDescriptor smem_tensormap_C; - cute::TmaDescriptor smem_tensormap_D; - } tensormaps; - - using PipelineStorage = typename LoadPipeline::SharedStorage; - PipelineStorage pipeline; - }; - using TensorStorage = typename SharedStorage::TensorStorage; - using TensorMapStorage = typename SharedStorage::TensorMapStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - // Planar complex kernels have two accumulator copies for the real and imaginary tensors. - constexpr static int NumAccumulatorMtxs = 1; - - static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; - - // Host side epilogue arguments - struct Arguments { - typename FusionCallbacks::Arguments thread{}; - ElementC const** ptr_C = nullptr; - StrideC dC{}; - ElementD** ptr_D = nullptr; - StrideD dD{}; - }; - - // Device side epilogue params - struct Params { - using TensorShapeC = decltype(repeat_like(append<3>(StrideC{}, _1{}), int32_t(0))); - using TensorShapeD = decltype(repeat_like(append<3>(StrideD{}, _1{}), int32_t(0))); - using TMA_C = decltype(make_tma_copy( - CopyOpG2S{}, - make_tensor( - make_gmem_ptr(static_cast(nullptr)), - TensorShapeC{}, - append<3>(InternalStrideC{}, _0{})), - SmemLayoutStageC{}, - EpilogueTile{}, - _1{})); - using TMA_D = decltype(make_tma_copy( - CopyOpS2G{}, - make_tensor( - make_gmem_ptr(static_cast(nullptr)), - TensorShapeD{}, - append<3>(InternalStrideD{}, _0{})), - SmemLayoutStageD{}, - EpilogueTile{}, - _1{})); - - typename FusionCallbacks::Params thread{}; - TMA_C tma_load_c; - TMA_D tma_store_d; - cute::TmaDescriptor* tensormaps; - ElementC const** ptr_C; - StrideC dC; - ElementD** ptr_D; - StrideD dD; - }; - - // - // Gemm Host Functions - // - - template - static constexpr Params - to_underlying_arguments( - ProblemShape const& problem_shape, - Arguments const& args, - void* workspace) { - // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. - // These will be replaced with correct values before the initial tma load. - auto init_M = int32_t(size<0>(CtaTileShape{})); - auto init_N = int32_t(size<1>(CtaTileShape{})); - auto init_L = 1; - - InternalStrideC stride_c; - InternalStrideD stride_d; - if constexpr (IsGroupedGemmKernel) { - // Strides for Grouped Gemm will be replaced prior to the first access regardless. - stride_c = InternalStrideC{}; - stride_d = InternalStrideD{}; - } - else { - // Tensor shapes for Ptr-Array are initialized correctly only here. - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1); - init_M = get<0>(problem_shape_MNKL); - init_N = get<1>(problem_shape_MNKL); - - stride_c = args.dC; - stride_d = args.dD; - } - - typename Params::TMA_C tma_load_c{}; - if constexpr (is_source_supported) { - // Tensor pointers will be fixed before the first access - ElementC const* ptr_C_first_batch = nullptr; - Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{}))); - tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutStageC{}, EpilogueTile{}, _1{}); - } - - typename Params::TMA_D tma_store_d{}; - if constexpr (is_destination_supported) { - // Tensor pointers will be fixed before the first access - ElementD* ptr_D_first_batch = nullptr; - Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); - tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutStageD{}, EpilogueTile{}, _1{}); - } - - auto fusion_workspace = static_cast(workspace); - auto fusion_workspace_size = round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment); - auto tma_descriptor_workspace = reinterpret_cast( - static_cast(workspace) + fusion_workspace_size); - - return { - FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, fusion_workspace), - tma_load_c, - tma_store_d, - tma_descriptor_workspace, - args.ptr_C, - args.dC, - args.ptr_D, - args.dD - }; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { - constexpr uint32_t NumInputTensors = cute::is_void_v ? 1 : 2; - constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); - // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies - return (NumInputTensors * SizeOfCuTensorMap * sm_count) + (round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment)); - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); - } - - template - static bool - can_implement( - ProblemShape problem_shape, - [[maybe_unused]] Arguments const& args) { - bool implementable = true; - bool fusion_implementable = true; - - if (problem_shape.is_host_problem_shape_available()) { - for (int i = 0; i < problem_shape.groups(); ++i) { - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); - auto [M,N,K,L] = problem_shape_MNKL; - - if constexpr (is_destination_supported) { - constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); - constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); - } - - if constexpr (is_source_supported) { - constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); - constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); - } - - fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); - } - } - else { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n"); - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - } - - if (!fusion_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); - } - - bool beta_implementable = true; - - if (cute::is_void_v || args.ptr_C == nullptr) { - if constexpr (detail::has_beta::value) { - beta_implementable = args.thread.beta == 0.0; - } - if constexpr (detail::has_beta_ptr::value) { - beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; - } - if constexpr (detail::has_beta_ptr_array::value) { - beta_implementable = beta_implementable && args.thread.beta_ptr_array == nullptr; - } - } - - if (!beta_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); - } - - return implementable && fusion_implementable && beta_implementable; - } - - // - // Static Device Functions - // - - template - CUTLASS_DEVICE - static constexpr int - get_load_pipe_increment(CtaTileMNK const& cta_tile_mnk) { - // Compute number of epilogue subtiles - return size<1>(zipped_divide(make_layout(take<0,2>(cta_tile_mnk)), EpilogueTile{})); - } - - template - CUTLASS_DEVICE - static constexpr int - get_store_pipe_increment(CtaTileMNK const& cta_tile_mnk) { - return get_load_pipe_increment(cta_tile_mnk); - } - - // - // Constructor and Data Members - // - CUTLASS_DEVICE - CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) - : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} - -private: - Params const& params; - FusionCallbacks fusion_callbacks; - - // - // Non-static Device Functions - // -public: - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return fusion_callbacks.is_producer_load_needed(); - } - - CUTLASS_DEVICE auto - load_init( - Params const& params, - TensorMapStorage& shared_tensormap, - int32_t const sm_count, - int32_t const sm_idx) const { - // Fetch a copy of tensormaps for the CTA from Params - constexpr bool IsEpiLoad = true; - auto load_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); - return cute::make_tuple(load_tensormap); - } - - template< - bool ReuseTmem = false, - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class MmaTileMNK, - class TiledMma, - class TensorMapC - > - CUTLASS_DEVICE auto - load( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - TensorStorage& shared_tensors, - cute::tuple load_tensormap_info, - bool reverse_epi_n = false) { - using namespace cute; - - // Check to see if tensormaps have been replaced in gmem - if (get<1>(load_tensormap_info) /* did_batch_change */) { - tensormaps_fence_acquire(get<0>(load_tensormap_info)); - } - - int lane_idx = canonical_lane_idx(); - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; - - auto coord_shape = append<3>(make_shape(m_coord, n_coord),Int<0>{}); - - // Represent the full source tensor, slice to get the tile this CTA is currently responsible for - Tensor mC_mn = params.tma_load_c.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) - Tensor mC = coalesce(mC_mn, take<0,2>(cta_tile_mnk)); - Tensor gC = local_tile(mC, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) - - // Apply epilogue subtile, get matching smem tensor - auto ptr_sC = shared_tensors.collective.smem_C.begin(); - Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) - - // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) - ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); - Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) - Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (TMA,TMA_M,TMA_N,PIPE_C) - - // Get the fusion callbacks for the producer load warp - auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - tiled_mma, - EpilogueTile{}, - lane_idx - }; - auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - // Predication for TMA load (one thread issues TMA load) - bool issue_tma_load = cute::elect_one_sync(); - - // Pre-loop fusion callback entry point - pld_callbacks.begin(); - - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<3>(gC_epi); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<2>(gC_epi); ++iter_m) { - int epi_m = iter_m, epi_n = iter_n; - if constexpr (ReuseTmem) { - if (reverse_epi_n) { - epi_n = size<3>(gC_epi) - 1 - iter_n; - } - } - // Acquire the lock for this stage - constexpr uint16_t mcast_mask = 0; - uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); - load_pipeline.producer_acquire(load_pipe_producer_state); - - // Execute the TMA load for C if needed - if (issue_tma_load && is_C_load_needed) { - copy(params.tma_load_c.with(get<0>(load_tensormap_info), *tma_barrier, mcast_mask), - bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); - load_pipeline.producer_expect_transaction(load_pipe_producer_state); - } - - // Loop fusion callback entry point - pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); - - // Commit TMA loads for this stage and release the lock - load_pipeline.producer_commit(load_pipe_producer_state); - ++load_pipe_producer_state; - } - } - - // Post-loop fusion callback entry point - pld_callbacks.end(); - - return load_pipe_producer_state; - } - - CUTLASS_DEVICE void - load_tail( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state, - [[maybe_unused]] StorePipeline store_pipeline, - [[maybe_unused]] StorePipelineState store_pipe_producer_state) { - load_pipeline.producer_tail(load_pipe_producer_state); - } - - CUTLASS_DEVICE auto - store_init( - Params const& params, - TensorMapStorage& shared_tensormap, - int32_t const sm_count, - int32_t const sm_idx) const { - // Fetch a copy of tensormaps for the CTA from Params - constexpr bool IsEpiLoad = false; - cute::TmaDescriptor* store_tensormap = nullptr; - int thread_idx = threadIdx.x % ThreadCount; - int warp_idx = thread_idx / NumThreadsPerWarp; - // Only the first epilogue warp needs to perform TMA related operations - if (warp_idx == 0) { - store_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); - } - return cute::make_tuple(store_tensormap); - } - - template< - bool ReuseTmem = false, - class AccumulatorPipeline, - class AccumulatorPipelineState, - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class MmaTileMNK, - class TiledMma, - class AccEngine, - class AccLayout, - class TensorMapD - > - CUTLASS_DEVICE auto - store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - AccumulatorPipeline acc_pipeline, - AccumulatorPipelineState acc_pipe_consumer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor accumulators, - TensorStorage& shared_tensors, - cute::tuple store_tensormap_info - ) { - using namespace cute; - using ElementAccumulator = typename AccEngine::value_type; - using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; - using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; - - static_assert(is_tmem::value, "Accumulator must be TMEM resident."); - static_assert(rank(accumulators) == 3, "Accumulators must be MMA-partitioned: [MMA, MMA_M, MMA_N]"); - static_assert(size<1>(accumulators) == 1 && size<2>(accumulators) == 1, "TiledMMA must match partitioned ShapeMN"); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); - - // Indexing variables - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; - int thread_idx = threadIdx.x % ThreadCount; - int warp_idx = thread_idx / NumThreadsPerWarp; - [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; - - // Check to see if tensormaps have been replaced in gmem - // Only the first epilogue warp needs to perform TMA related operations - if (get<1>(store_tensormap_info) /* did_batch_change */ && warp_idx == 0) { - tensormaps_fence_acquire(get<0>(store_tensormap_info)); - } - - auto coord_shape = append<3>(make_shape(m_coord, n_coord),Int<0>{}); - - // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) - Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); - Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) - - Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) - - // Apply epilogue subtiling - Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Construct the corresponding pipelined smem tensors - auto ptr_sC = shared_tensors.collective.smem_C.begin(); - auto ptr_sD = shared_tensors.collective.smem_D.begin(); - Tensor sC_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) - Tensor sD_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) - - // (t)hread-partition for (t)mem to (r)egister copy (tTR_) - TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); - ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) - - // Allocate D and accumulator registers - // Does directly store the visitor into smem. - constexpr bool IsDirectR2S = cute::is_same_v>; - using RegisterElementD = cute::conditional_t; - Tensor tTR_rAcc = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) - Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) - - // Vectorized fragment view - constexpr int FragmentSize = DispatchPolicy::FragmentSize; - Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); // (EPI_V) - Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) - CUTE_STATIC_ASSERT(size(tTR_rAcc) % DispatchPolicy::FragmentSize == 0, "Fragment size does not vectorize properly"); - - // (t)hread-partition for (s)mem to (r)egister copy (tSR_) - TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); - Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) - Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) - - // Allocate C registers - // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type - // to eliminate some redundant pack+unpack instruction sequences for sub-word types - constexpr bool IsDirectS2R = cute::is_same_v> - && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; - using RegisterElementC = cute::conditional_t; - Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) - Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) - - // (t)hread-partition for (r)egister to (r)egister copy (tRR_) - TiledCopy tiled_r2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); - Tensor tRR_rD_src = thread_r2r.retile_S(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) - Tensor tRR_rD_dst = thread_r2r.retile_D(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) - - // (t)hread-partition for (r)egister to (s)mem copy (tRS_) - TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_r2r); - ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); - Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) - Tensor tRS_rD = [&]() { - if constexpr (!IsDirectR2S) { - return make_tensor(shape(tRS_sD(_,_,_,_0{}))); - } - else{ - return thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) - } - }(); - - Tensor tRR_rD_dst_frg = recast>(coalesce(tRR_rD_dst)); - Tensor tRS_rD_frg = recast>(coalesce(tRS_rD)); - - // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) - ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); - Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) - Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - - // OOB predication for tile quantization "residue" - // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - // Relative coordinate tensors (static) - Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) - Tensor tTR_cD = make_coord_tensor(tTR_cD_mn.layout()); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate - auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) - auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) - - // Get the fusion callbacks for the consumer store warps - constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - tiled_mma, - EpilogueTile{}, - tiled_t2r, - cD, - residue_cD, - tTR_cD, - residue_tTR_cD, - tTR_rC, - thread_idx - }; - - // Thread synchronizer for previously issued waits or fences - // to ensure visibility of smem reads/writes to threads or TMA unit - auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // Predication for sub-128 thread T2R tiled copy - Layout tmem_warp_layout = typename decltype(make_tmem_warp_partitioner(tAcc_epi(_,_,0,0)))::TiledLayout_TV{}; - constexpr bool predicate_tmem_load = size(tmem_warp_layout) != cosize(tmem_warp_layout); - bool issue_tmem_load = true; - - // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) - // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of - // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. - // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. - // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. - // Then, the next accumulation stage for buffer 1 can start. - [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; - static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); - - // Predication for TMA store (a single thread from one warp issues TMA store) - bool issue_tma_store = (warp_idx == 0) && cute::elect_one_sync(); - - // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. - // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can - // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. - // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. - // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. - LoadPipelineState load_wait_state = load_pipe_consumer_state; - if constexpr (ReuseSmemC) { - load_wait_state = store_pipe_producer_state; - load_wait_state.phase_ ^= 1; - } - - // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions - // Sync requirements of smem reuse may preclude this optimization - // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD - [[maybe_unused]] int epi_m_prev = 0; - [[maybe_unused]] int epi_n_prev = 0; - static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); - - // The Epilogue Loop - auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - // The TMA store sequence for one epilogue loop iteration - auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { - // Write the tile from smem to gmem with TMA - cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - synchronize(); // ensure all threads have issued their async fence - - if constexpr (is_destination_supported) { - if (issue_tma_store) { - copy(params.tma_store_d.with(get<0>(store_tensormap_info)), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); - } - } - - // Post async fence, pre TMA commit callback entry point - cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); - - // Commit the TMA stores for this stage - if (issue_tma_store) { - store_pipeline.producer_commit(store_pipe_producer_state); - } - ++store_pipe_producer_state; - - // Wait for the next smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); - } - synchronize(); - - if constexpr (ReuseSmemC) { - // producer_acquire returns when at most StagesD-1 committed stores are pending - bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; - // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits - if (store_finished) { - if (is_producer_load_needed) { - load_pipeline.consumer_release(load_pipe_consumer_state); - } - ++load_pipe_consumer_state; - } - } - }; // tma_store_fn - - // Begin the wait for the producer load results - ConsumerToken load_wait_token{BarrierStatus::WaitDone}; - if (is_producer_load_needed) { - load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); - } - // Begin the wait for the accumulator results - ConsumerToken acc_wait_token = acc_pipeline.consumer_try_wait(acc_pipe_consumer_state); - - cst_callbacks.begin(); - if (cst_callbacks.begin_sync_needed()) { - synchronize(); - } - // For each epilogue subtile within the CTA tile - constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); - constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); - #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) - for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { - #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) - for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { - int epi_m = iter_m, epi_n = iter_n; - bool is_first_iteration = iter_m == 0 && iter_n == 0; - bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; - bool do_acc_release = is_last_iteration; - - // Reverse subtile order for tmem reuse if necessary - if constexpr (ReuseTmem) { - if (reverse_epi_n) { - epi_n = size<3>(gD_epi) - 1 - iter_n; - } - do_acc_release = iter_m == size<2>(gD_epi)-1 && iter_n == 0; - } - - cst_callbacks.begin_loop(epi_m, epi_n); - - if (is_producer_load_needed) { - // Wait for the producer load to fill smem - load_pipeline.consumer_wait(load_wait_state, load_wait_token); - - if (is_C_load_needed) { - // Copy source tile from smem to register - copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); - // Ensure smem loads are complete before reusing smem for mixed types/layouts - if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { - synchronize(); - } - } - } - - // First loop fusion callback entry point - cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); - - if (is_producer_load_needed) { - // Let producer load warp know smem buffers are consumed and empty - if constexpr (not ReuseSmemC) { - cutlass::arch::fence_view_async_shared(); - load_pipeline.consumer_release(load_pipe_consumer_state); - ++load_pipe_consumer_state; - } - ++load_wait_state; - } - - if (is_first_iteration) { - // Wait for mma warp to fill tmem buffer with accumulator results - acc_pipeline.consumer_wait(acc_pipe_consumer_state, acc_wait_token); - } - - // The current tile in tmem - Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); - - // Compute tmem load predication if necessary - if constexpr (predicate_tmem_load) { - // Issue tmem load if this tile's tmem subpartition is accessible by this warp - int subpart_idx = (tTR_tAcc_mn.data().dp_ / 32) % 4; - issue_tmem_load = warp_idx == subpart_idx; - } - - // Copy accumulator tile from tmem to register - if (issue_tmem_load) { - copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); - } - - // After the last tmem load, signal that tmem buffer is consumed and empty - if (do_acc_release) { - cutlass::arch::fence_view_async_tmem_load(); - acc_pipeline.consumer_release(acc_pipe_consumer_state); - ++acc_pipe_consumer_state; - } - - // Vectorized fragment loop with visitor callback entry point - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { - tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); - } - - // The latest we can delay the TMA store is right before the smem store of the next iteration - // since the current TMA store needs to be committed before we can acquire the next smem buffer - if constexpr (DelayTmaStore) { - // Issue TMA stores for the previous subtile - if (not is_first_iteration) { - tma_store_fn(epi_m_prev, epi_n_prev); - } - epi_m_prev = epi_m; - epi_n_prev = epi_n; - } - - if constexpr (!IsDirectR2S) { - // At present, only FP4 col output with scalefactor generation fusion would go into these branch - copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); - } - tRS_rD_frg(_0{}) = cutlass::NumericArrayConverter{}(tRR_rD_dst_frg(_0{})); - - // Smem reduction callback entry point using current store buffer for workspace - Tensor reduction_buffer = make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), - make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); - cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); - - // Copy output tile from register to smem - bool issue_smem_store = issue_tmem_load; - if constexpr (is_destination_supported) { - if (issue_smem_store) { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - } - } - - // Post reduction, pre TMA store callback entry point - cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); - - if constexpr (not DelayTmaStore) { - // Issue TMA stores for this subtile - tma_store_fn(epi_m, epi_n); - } - - cst_callbacks.end_loop(epi_m, epi_n); - - if (is_producer_load_needed) { - // Begin the wait for the next subtile producer load - load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); - } - } // for epi_m - } // for epi_n - - if constexpr (DelayTmaStore) { - // Issue TMA stores for the last subtile - tma_store_fn(epi_m_prev, epi_n_prev); - } - - cst_callbacks.end(); - }; - - // - // BEGIN EPILOGUE - // - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - epi_loop_fn(cst_callbacks); - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); - } - - // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. - // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator - // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the - // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. - template< - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class MmaTileMNK, - class TiledMma, - class AccEngine, - class AccLayout, - class TiledCopyT2R, - class TensorMapD - > - CUTLASS_DEVICE auto - store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor& tTR_rAcc, // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - TensorStorage& shared_tensors, - TensorMapD store_tensormap, - TiledCopyT2R tiled_t2r - ) { - using namespace cute; - using ElementAccumulator = typename AccEngine::value_type; - using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; - using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; - - static_assert(is_rmem::value, "Accumulator must be Register resident."); - static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); - - // Indexing variables - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; - int thread_idx = threadIdx.x % ThreadCount; - int warp_idx = thread_idx / NumThreadsPerWarp; - [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; - - auto coord_shape = append<3>(make_shape(m_coord, n_coord),Int<0>{}); - - // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) - Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); - Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) - - // Apply epilogue subtiling - Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Construct the corresponding pipelined smem tensors - auto ptr_sC = shared_tensors.collective.smem_C.begin(); - auto ptr_sD = shared_tensors.collective.smem_D.begin(); - Tensor sC_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) - Tensor sD_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) - - // (t)hread-partition for (t)mem to (r)egister copy (tTR_) - ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) - - // Allocate D and accumulator registers - Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) - - // Vectorized fragment view - constexpr int FragmentSize = DispatchPolicy::FragmentSize; - Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) - - // (t)hread-partition for (s)mem to (r)egister copy (tSR_) - TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); - Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) - Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) - - // Allocate C registers - // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type - // to eliminate some redundant pack+unpack instruction sequences for sub-word types - constexpr bool IsDirectS2R = cute::is_same_v> - && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; - using RegisterElementC = cute::conditional_t; - Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) - Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) - - // (t)hread-partition for (r)egister to (s)mem copy (tRS_) - TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); - Tensor tRS_rD = thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) - Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) - - // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) - ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); - Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) - Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - - // OOB predication for tile quantization "residue" - // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - // Relative coordinate tensors (static) - Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) - Tensor tTR_cD = make_coord_tensor(tTR_cD_mn.layout()); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate - auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) - auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) - - // Get the fusion callbacks for the consumer store warps - constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - tiled_mma, - EpilogueTile{}, - tiled_t2r, - cD, - residue_cD, - tTR_cD, - residue_tTR_cD, - tTR_rC, - thread_idx - }; - - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - // Thread synchronizer for previously issued waits or fences - // to ensure visibility of smem reads/writes to threads or TMA unit - auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // Predication for TMA store (one warp issues TMA store) - bool issue_tma_store = warp_idx == 0; - - // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. - // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can - // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. - // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. - // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. - LoadPipelineState load_wait_state = load_pipe_consumer_state; - if constexpr (ReuseSmemC) { - load_wait_state = store_pipe_producer_state; - load_wait_state.phase_ ^= 1; - } - - // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions - // Sync requirements of smem reuse may preclude this optimization - // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD - int epi_m_prev = 0, epi_n_prev = 0; - static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); - - // The TMA store sequence for one subtile iteration - auto tma_store_fn = [&] (int epi_m, int epi_n) { - // Write the tile from smem to gmem with TMA - cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - synchronize(); // ensure all threads have issued their async fence - if (issue_tma_store) { - copy(params.tma_store_d.with(store_tensormap), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); - } - - // Post async fence, pre TMA commit callback entry point - cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); - - // Commit the TMA stores for this stage - if (issue_tma_store) { - store_pipeline.producer_commit(store_pipe_producer_state); - } - ++store_pipe_producer_state; - - // Wait for the next smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); - } - synchronize(); - - if constexpr (ReuseSmemC) { - // producer_acquire returns when at most StagesD-1 committed stores are pending - bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; - // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits - if (store_finished) { - if (is_producer_load_needed) { - load_pipeline.consumer_release(load_pipe_consumer_state); - } - ++load_pipe_consumer_state; - } - } - }; - - // - // BEGIN EPILOGUE - // - - // Begin the wait for the producer load results - ConsumerToken load_wait_token{BarrierStatus::WaitDone}; - if (is_producer_load_needed) { - load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); - } - - cst_callbacks.begin(); - if (cst_callbacks.begin_sync_needed()) { - synchronize(); - } - - // For each epilogue subtile within the CTA tile - constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); - constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); - #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) - for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { - #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) - for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { - int epi_m = iter_m, epi_n = iter_n; - bool is_first_iteration = iter_m == 0 && iter_n == 0; - bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; - - cst_callbacks.begin_loop(epi_m, epi_n); - - if (is_producer_load_needed) { - // Wait for the producer load to fill smem - load_pipeline.consumer_wait(load_wait_state, load_wait_token); - - if (is_C_load_needed) { - // Copy source tile from smem to register - copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); - // Ensure smem loads are complete before reusing smem for mixed types/layouts - if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { - synchronize(); - } - } - } - - // First loop fusion callback entry point - cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); - - if (is_producer_load_needed) { - // Let producer load warp know smem buffers are consumed and empty - if constexpr (not ReuseSmemC) { - cutlass::arch::fence_view_async_shared(); - load_pipeline.consumer_release(load_pipe_consumer_state); - ++load_pipe_consumer_state; - } - ++load_wait_state; - } - - bool issue_smem_store = true; - Tensor tTR_rAcc_epi_tile = tTR_rAcc(_,_,_,epi_m,epi_n); - Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc_epi_tile)); // (EPI_V) - - // Vectorized fragment loop with visitor callback entry point - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { - tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); - } - - // The latest we can delay the TMA store is right before the smem store of the next iteration - // since the current TMA store needs to be committed before we can acquire the next smem buffer - if constexpr (DelayTmaStore) { - // Issue TMA stores for the previous subtile - if (not is_first_iteration) { - tma_store_fn(epi_m_prev, epi_n_prev); - } - epi_m_prev = epi_m; - epi_n_prev = epi_n; - } - - // Smem reduction callback entry point using current store buffer for workspace - Tensor reduction_buffer = make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), - make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); - cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rD_frg); - - // Copy output tile from register to smem - if (issue_smem_store) { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - } - - // Post reduction, pre TMA store callback entry point - cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); - - if constexpr (not DelayTmaStore) { - // Issue TMA stores for this subtile - tma_store_fn(epi_m, epi_n); - } - - cst_callbacks.end_loop(epi_m, epi_n); - - if (is_producer_load_needed) { - // Begin the wait for the next subtile producer load - load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); - } - } // for epi_m - } // for epi_n - - if constexpr (DelayTmaStore) { - // Issue TMA stores for the last subtile - tma_store_fn(epi_m_prev, epi_n_prev); - } - - cst_callbacks.end(); - - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); - } - - template - CUTLASS_DEVICE void - store_tail( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - CtaTileMNK cta_tile_mnk) { - if constexpr (ReuseSmemC) { - if (fusion_callbacks.is_producer_load_needed()) { - // wait for all TMA stores to complete - store_pipeline.producer_tail(store_pipe_producer_state); - - // Issue releases on up to StagesD-1 previously issued TMA stores - constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(cta_tile_mnk)); - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < release_stages; ++stage) { - load_pipeline.consumer_release(load_pipe_consumer_state); - ++load_pipe_consumer_state; - } - } - } - } - - // - // Methods to perform different parts of TMA/Tensormap modifications - // - - template - CUTLASS_DEVICE auto - tensormaps_init(Params const& params, - TensorMapStorage& shared_tensormap, - int32_t const sm_count, - int32_t const sm_idx) const { - cute::TmaDescriptor* tma_desc = nullptr; - cute::TmaDescriptor* gmem_tensormap = params.tensormaps; - if constexpr (IsLoad) { - if (is_source_supported) { - tma_desc = &gmem_tensormap[sm_idx]; - if (cute::elect_one_sync()) { - // Bringing tensormaps from params to smem for modification later - Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_C), Int<1>{}, Int<1>{}); - copy(recast(pC_tensormap), recast(sC_tensormap)); - } - __syncwarp(); - } - } else if constexpr (is_destination_supported) { - int const offset_Ddesc = cute::is_void_v ? 0 : sm_count; - tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc]; - if (cute::elect_one_sync()) { - // Bringing tensormaps from params to smem for modification later - Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_D), Int<1>{}, Int<1>{}); - copy(recast(pD_tensormap), recast(sD_tensormap)); - } - __syncwarp(); - } - - return tma_desc; - } - - // Replace address for the global tensor (to be done by single thread) - template - CUTLASS_DEVICE - void - tensormaps_replace_global_address( - TensorMapStorage& shared_tensormap, - Params const& params, - int32_t next_batch) { - // Replacing global_address for the next batch - if constexpr (IsLoad) { - if constexpr (is_source_supported) { - if (params.ptr_C != nullptr) { - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_C, - params.ptr_C[next_batch]); - } - } - } else if constexpr (is_destination_supported) { - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D, - params.ptr_D[next_batch]); - } - } - - // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) - template - CUTLASS_DEVICE - void - tensormaps_replace_global_tensor_properties( - TensorMapStorage& shared_tensormaps, - Params const& params, - int32_t next_group, - ProblemShape_MNKL problem_shape_mnkl) { - const uint32_t M = get<0>(problem_shape_mnkl); - const uint32_t N = get<1>(problem_shape_mnkl); - // Replace all dims for consistency - constexpr int MaxTensorRank = 5; - cute::array prob_shape = {1,1,1,1,1}; - cute::array prob_stride = {0,0,0,0,0}; - - if constexpr (IsLoad) { - if constexpr (is_source_supported) { - if (params.dC != nullptr) { - ElementC const* ptr_C = nullptr; - Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); - - cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, - prob_shape, prob_stride); - // Convert strides to byte strides - for (uint64_t& stride : prob_stride) { - stride = (stride * sizeof_bits_v) / 8; - } - cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, - prob_shape, - prob_stride); - } - } - } - else if constexpr (is_destination_supported) { - ElementD const* ptr_D = nullptr; - Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); - - cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d, - prob_shape, prob_stride); - // Convert strides to byte strides - for (uint64_t& stride : prob_stride) { - stride = (stride * sizeof_bits_v) / 8; - } - - cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D, - prob_shape, - prob_stride); - } - } - - // The entire warp must call this function collectively (that is, the instructions are aligned) - template - CUTLASS_DEVICE - void - tensormaps_perform_update( - TensorMapStorage& shared_tensormap, - Params const& params, - cute::TmaDescriptor const* tensormap, - ProblemShape problem_shape, - int32_t next_batch) { - if (cute::elect_one_sync()) { - // Replacing global_address for the next batch - tensormaps_replace_global_address(shared_tensormap, params, next_batch); - - if constexpr (IsGroupedGemmKernel) { - auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); - // Replacing global dims and strides for the next batch - tensormaps_replace_global_tensor_properties( - shared_tensormap, params, next_batch, problem_shape_MNKL); - } - } - // Ensure warp is converged before issuing tensormap fence release - __syncwarp(); - // Entire warp must do this (ie its aligned) - tensormaps_cp_fence_release(shared_tensormap, tensormap); - } - - template - CUTLASS_DEVICE - void - tensormaps_cp_fence_release( - TensorMapStorage& shared_tensormap, - cute::TmaDescriptor const* tensormap) { - // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem. - // This operation only happens when the group/batch changes between consecutive tiles. - // If there are no uncommitted instructions then tma_desc_commit_group results in an empty bulk async-group. - auto tma_desc_wait_all_fn = [] () CUTLASS_LAMBDA_FUNC_INLINE { - if (cute::elect_one_sync()) { - cute::tma_desc_commit_group(); - cute::tma_desc_wait_group(); - } - }; - // Entire warp must do this (ie its aligned) - if constexpr (IsLoad) { - if (is_source_supported) { - tma_desc_wait_all_fn(); - tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C); - } - } else if constexpr (is_destination_supported) { - tma_desc_wait_all_fn(); - tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D); - } - } - - template - CUTLASS_DEVICE - void - tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { - if constexpr (IsLoad) { - if (is_source_supported) { - cute::tma_descriptor_fence_acquire(tensormap); - } - } else if constexpr (is_destination_supported) { - cute::tma_descriptor_fence_acquire(tensormap); - } - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp deleted file mode 100644 index 90dfb80c00b7c4c48ce74d69cca52aeea8b80baa..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp +++ /dev/null @@ -1,856 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/detail/helper_macros.hpp" -#include "cutlass/conv/convnd_problem_shape.hpp" -#include "cutlass/conv/detail.hpp" - -#include "cute/tensor.hpp" -#include "cute/numeric/numeric_types.hpp" -#include "cutlass/cuda_host_adapter.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -template -struct IsDefaultFusionOp { - static constexpr bool value = false; -}; - -template< - class ElementD, class ElementCompute, - class ElementC, FloatRoundStyle RoundStyle -> -struct IsDefaultFusionOp< - epilogue::fusion::LinearCombination< - ElementD, ElementCompute, ElementC, ElementCompute, RoundStyle> -> { - static constexpr bool value = true; -}; - -template< - class ElementOutput, int Count, class ElementAccumulator, - class ElementCompute, epilogue::thread::ScaleType::Kind Scale, - FloatRoundStyle Round, class ElementSource -> -struct IsDefaultFusionOp< - epilogue::thread::LinearCombination< - ElementOutput, Count, ElementAccumulator, - ElementCompute, Scale, Round, ElementSource> -> { - static constexpr bool value = true; -}; - -// Legacy direct store sm100 epilogue using thread::LinearCombination, do not expect this to be stable -template < - class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class ThreadEpilogueOp_, - class CopyOpT2R_, - class AlignmentC_, - class AlignmentD_ -> -class CollectiveEpilogue< - Sm100NoSmem, - EpilogueTile_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - ThreadEpilogueOp_, - CopyOpT2R_, - AlignmentC_, - AlignmentD_, - cute::enable_if_t::value> -> { -public: - // - // Type Aliases - // - using DispatchPolicy = Sm100NoSmem; - using EpilogueTile = EpilogueTile_; - // derived types of output thread level operator - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementScalar = ElementCompute; - using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; - using ElementC = ElementC_; - using StrideC = StrideC_; - using ElementD = ElementD_; - using StrideD = StrideD_; - using CopyOpT2R = CopyOpT2R_; - using AlignmentC = AlignmentC_; - using AlignmentD = AlignmentD_; - using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages - - using GmemTiledCopyC = void; - using GmemTiledCopyD = void; - - constexpr static int ThreadCount = 128; - constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; - constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; - constexpr static bool isSourceNeeded = not cute::is_void_v; - - using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - constexpr static uint32_t TmaTransactionBytes = 0; - - struct SharedStorage { }; - - // Host side epilogue arguments - struct Arguments { - typename ThreadEpilogueOp::Params thread{}; - ElementC const* ptr_C = nullptr; - StrideC dC{}; - ElementD* ptr_D = nullptr; - StrideD dD{}; - }; - - // Device side epilogue params - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments( - [[maybe_unused]] ProblemShape const& problem_shape, - Arguments const& args, - [[maybe_unused]] void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - template - static bool - can_implement(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args) { - return can_implement(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args); - } - - template - static bool - can_implement( - [[maybe_unused]] ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - auto shape = cute::make_shape(M,N,L); - - bool implementable = true; - implementable = implementable && cutlass::detail::check_alignment(shape, StrideD{}); - if constexpr (isSourceNeeded) { - implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); - } - return implementable; - } - - // - // Constructor and Data Members - // - CUTLASS_DEVICE - CollectiveEpilogue(Params const& params, SharedStorage&) : params(params) { }; - -protected: - Params const& params; - - // - // Non-static Device Methods - // -public: - template< - bool ReuseTmem = false, - class AccumulatorPipeline, - class AccumulatorPipelineState, - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class AccEngine, class AccLayout - > - CUTLASS_DEVICE auto - operator()( - AccumulatorPipeline acc_pipeline, - AccumulatorPipelineState acc_pipe_consumer_state, - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK cta_tile_shape_mnk, - TileCoordMNKL cta_coord_mnkl, - cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) - [[maybe_unused]] SharedStorage&) { - - using namespace cute; - using X = Underscore; - - static_assert(is_tmem::value, "Accumulator must be TMEM resident."); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); - - auto problem_shape_mnl = select<0,1,3>(problem_shape_mnkl); - auto cta_coord_mnl = select<0,1,3>(cta_coord_mnkl); - auto cta_tiler = take<0,2>(cta_tile_shape_mnk); - - // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) - Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) - Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - - // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) - auto tiled_t2r = make_tmem_copy(CopyOpT2R{}, tensor<0>(accumulators)); - auto thread_idx = threadIdx.x % size(tiled_t2r); - - auto thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) - Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) - Tensor tTR_rAcc = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) - - Tensor tTR_rC = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) - - Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) - Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) - Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) - - constexpr auto mclD = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gD.layout())){}; - constexpr int VD = cute::min(AlignmentD{}, size(mclD)); - Tensor tTR_rD_frag = make_tensor(shape(tTR_rAcc)); - Tensor tTR_rD_src = recast>(coalesce(tTR_rD_frag)); - Tensor tR2G_rD_dst = recast>(coalesce(tTR_gD)); - - Tensor tTR_cD_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclD.compose(Int{}))); - Tensor tDpD = make_tensor(shape(tR2G_rD_dst)); - - CUTLASS_PRAGMA_UNROLL - for (int t = 0; t < size(tDpD); t++) { - tDpD(t) = elem_less(tTR_cD_mn_frg(t), problem_shape_mnl); - } - - constexpr auto mclC = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gC.layout())){}; - constexpr int VC = cute::min(AlignmentC{}, size(mclC)); - - Tensor tTR_cC_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclC.compose(Int{}))); - Tensor tG2R_rC_dst = recast>(coalesce(tTR_gC)); - Tensor tCpC = make_tensor(shape(tG2R_rC_dst)); - - CUTLASS_PRAGMA_UNROLL - for (int t = 0; t < size(tCpC); t++) { - tCpC(t) = elem_less(tTR_cC_mn_frg(t), problem_shape_mnl); - } - Tensor tTR_rC_src = recast>(coalesce(tTR_gC)); - Tensor tTR_rC_dst = recast>(coalesce(tTR_rC)); - - // Detect interleaved complex fp32 kernels - [[maybe_unused]] Tensor accs = accumulators; - using ElementTmem = typename decltype(accs)::value_type; - constexpr bool is_interleaved_complex_f32 = is_complex::value && cute::is_same_v; - - // 1. Load accumulators into register from tmem - // Tmem -> rmem and transformation for interleaved complex kernels - if constexpr (is_interleaved_complex_f32) { - using ElementComputeAccumulator = float; - - Tensor tAccReal = accumulators(make_coord(_,_),_0{},_0{},_0{}); // (CTA_M,CTA_N) - Tensor tAccImag = accumulators(make_coord(_,_),_0{},_0{},_1{}); // (CTA_M,CTA_N) - Tensor tTR_tAccReal = thread_t2r.partition_S(tAccReal); // (T2R,T2R_M,T2R_N) - Tensor tTR_tAccImag = thread_t2r.partition_S(tAccImag); // (T2R,T2R_M,T2R_N) - Tensor tTR_rAccReal = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) - Tensor tTR_rAccImag = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) - - copy(tiled_t2r, tTR_tAccReal, tTR_rAccReal); - copy(tiled_t2r, tTR_tAccImag, tTR_rAccImag); - - // 1.1. Transform accumulators in registers - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rAccReal); i++) { - tTR_rAcc(i) = {tTR_rAccReal(i), tTR_rAccImag(i)}; - } - } - - // Standard tmem -> rmem epilogue - else { - Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) - Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); // (T2R,T2R_M,T2R_N) - - copy(tiled_t2r, tTR_tAcc, tTR_rAcc); - } - - cutlass::arch::fence_view_async_tmem_load(); - acc_pipeline.consumer_release(acc_pipe_consumer_state); - ++acc_pipe_consumer_state; - - // 2. Apply element-wise operation and store to gmem - ThreadEpilogueOp epilogue_op{params.thread}; - // source is needed - if (epilogue_op.is_source_needed()) { - copy_if(tCpC, tTR_rC_src, tTR_rC_dst); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rAcc); i++) { - tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i), tTR_rC(i)); - } - - copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); - } - // source is not needed, avoid load - else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rAcc); i++) { - tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i)); - } - - copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); - } - - return cute::make_tuple(acc_pipe_consumer_state); - } - - - // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. - // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator - // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the - // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. - template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class AccEngine, class AccLayout, - class TiledCopy - > - CUTLASS_DEVICE void - operator()( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK cta_tile_shape_mnk, - TileCoordMNKL cta_coord_mnkl, - cute::Tensor& tTR_rGlobAcc, // (MMA,MMA_M,MMA_N) - [[maybe_unused]] SharedStorage&, - TiledCopy tiled_t2r) { - - using namespace cute; - using X = Underscore; - - static_assert(is_rmem::value, "Accumulator must be Register resident."); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); - static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); - - auto problem_shape_mnl = select<0,1,3>(problem_shape_mnkl); - auto cta_coord_mnl = select<0,1,3>(cta_coord_mnkl); - auto cta_tiler = take<0,2>(cta_tile_shape_mnk); - - // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) - Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) - Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - - - // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) - auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); - Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) - Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) - - - Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) - Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) - Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) - - // 2. Apply element-wise operation and store to gmem - ThreadEpilogueOp epilogue_op{params.thread}; - // source is needed - if (epilogue_op.is_source_needed()) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rGlobAcc); ++i) { - if (elem_less(tTR_cCD(i), problem_shape_mnl)) { - tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i), tTR_gC(i)); - } - } - } - // source is not needed, avoid load - else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rGlobAcc); ++i) { - if (elem_less(tTR_cCD(i), problem_shape_mnl)) { - tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i)); - } - } - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Direct store sm100 epilogue supporting EVT -template < - class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class FusionCallbacks_, - class CopyOpT2R_, - class AlignmentC_, - class AlignmentD_ -> -class CollectiveEpilogue< - Sm100NoSmem, - EpilogueTile_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - FusionCallbacks_, - CopyOpT2R_, - AlignmentC_, - AlignmentD_, - cute::enable_if_t::value> -> { -public: - // - // Type Aliases - // - // Required by the gemm::kernel - using DispatchPolicy = Sm100NoSmem; - using ElementC = ElementC_; - using ElementD = ElementD_; - using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages - using StrideC = StrideC_; - using StrideD = StrideD_; - using EpilogueTile = EpilogueTile_; - using CopyOpT2R = CopyOpT2R_; - using FusionCallbacks = FusionCallbacks_; - using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; - - using GmemTiledCopyC = void; - using GmemTiledCopyD = void; - -private: - constexpr static bool IsReductionBufferNeeded = ThreadEpilogueOp::IsDePerRowBiasSupported - || is_same_v; // alloc reduction buffer for custom EVTs - constexpr static size_t ImplicitSharedStorageSize = IsReductionBufferNeeded ? size(EpilogueTile{}) : 0; - - // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. - constexpr static bool UnrollEpiLoop = - not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; - -public: - constexpr static int ThreadCount = 128; - constexpr static uint32_t TmaTransactionBytes = 0; - - struct SharedStorage { - using FusionStorage = typename FusionCallbacks::SharedStorage; - FusionStorage thread; - array_aligned buffer; - }; - - // Host side epilogue arguments - struct Arguments { - typename FusionCallbacks::Arguments thread{}; - ElementC const* ptr_C = nullptr; - StrideC dC = {}; - ElementD* ptr_D = nullptr; - StrideD dD = {}; - }; - - // Device side epilogue params - struct Params { - typename FusionCallbacks::Params thread{}; - ElementC const* ptr_C = nullptr; - StrideC dC = {}; - ElementD* ptr_D = nullptr; - StrideD dD = {}; - }; - - // - // Constructor and Data Members - // - CUTLASS_DEVICE - CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors) - : fusion_callbacks(params_.thread, shared_tensors.thread) - , smem_buffer_ptr(shared_tensors.buffer.data()) - , params(params_) {}; - -protected: - FusionCallbacks fusion_callbacks; - uint8_t* smem_buffer_ptr; - Params const& params; - -public: - - template - static constexpr Params - to_underlying_arguments( - [[maybe_unused]] ProblemShape const& problem_shape, - Arguments const& args, - [[maybe_unused]] void* workspace) { - return { - FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), - args.ptr_C, - args.dC, - args.ptr_D, - args.dD - }; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return FusionCallbacks::get_workspace_size(problem_shape, args.thread); - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); - } - - template - static bool - can_implement( - [[maybe_unused]] ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - - bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); - if (!fusion_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); - } - return fusion_implementable; - } - - - template< - bool ReuseTmem = false, - class AccumulatorPipeline, - class AccumulatorPipelineState, - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class AccEngine, class AccLayout - > - CUTLASS_DEVICE auto - operator()( - AccumulatorPipeline acc_pipeline, - AccumulatorPipelineState acc_pipe_consumer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - cute::Tensor accumulators, - [[maybe_unused]] SharedStorage& - ) { - using ElementAccumulator = typename AccEngine::value_type; - using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; - using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; - - // Wait for mma warp to fill tmem buffer with accumulator results - static_assert(is_tmem::value, "Accumulator must be TMEM resident."); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(rank(CtaCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); - static_assert(cute::sizeof_bits_v != 6, "Output element requires smem"); - - auto [M, N, K, L] = problem_shape_mnkl; - auto problem_shape_mnl = select<0,1,3>(problem_shape_mnkl); - auto cta_coord_mnl = select<0,1,3>(cta_coord_mnkl); - auto cta_tiler = take<0,2>(cta_tile_mnk); - - int thread_idx = threadIdx.x % ThreadCount; - - Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) - Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); - ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - - constexpr int FragmentSize = size(EpilogueTile{}) / ThreadCount; - - Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) - Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) - Tensor cD_epi = flat_divide(cD, EpilogueTile{}); - Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l) - - Tensor tTR_rAcc = make_tensor(shape(tTR_cD(_,_,_,_0{},_0{}))); - - // Construct the EVT consumer callbacks - auto residue_cD = make_coord(M,N) - cD(_0{}); - auto residue_tTR_cD = make_coord(M,N) - tTR_cD(_0{}); - Tensor cD_ = make_coord_tensor(cD.layout()); - Tensor tTR_cD_ = make_coord_tensor(tTR_cD.layout()); - constexpr bool RefSrc = false; - - Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); - - Tensor tTR_gC = cutlass::epilogue::fusion::sm90_partition_for_epilogue( - mC, cta_tile_mnk, cta_coord_mnkl, EpilogueTile{}, tiled_t2r, thread_idx); - - Tensor mD = make_tensor(make_gmem_ptr(recast_ptr(params.ptr_D)), make_shape(M,N,L), params.dD); - - Tensor tTR_gD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( - mD, cta_tile_mnk, cta_coord_mnkl, EpilogueTile{}, tiled_t2r, thread_idx); - - // Register Tensor - Tensor tTR_rD = make_tensor(take<0,3>(shape(tTR_gD))); - - Tensor coord_cCD = make_identity_tensor(problem_shape_mnl); - Tensor tTR_cCD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( - coord_cCD, cta_tile_mnk, cta_coord_mnkl, EpilogueTile{}, tiled_t2r, thread_idx); - constexpr auto mclD = decltype(max_common_layout(tTR_gD(_,_,_,_0{},_0{}), tTR_rD)){}; - constexpr int VD = cute::min(AlignmentD_{}, size(mclD)); - - auto tCrC = make_tensor(take<0,3>(shape(tTR_gC))); - constexpr auto mclC = decltype(max_common_layout(tTR_gC(_,_,_,_0{},_0{}), tCrC)){}; - constexpr int VC = cute::min(AlignmentC_{}, size(mclC)); - - Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); - - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - int(0), - EpilogueTile{}, - tiled_t2r, - cD_, - residue_cD, - tTR_cD_, - residue_tTR_cD, - tCrC, - thread_idx - }; - - auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // The Epilogue Loop - auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - bool is_C_load_needed = fusion_callbacks.is_C_load_needed(); - - // Ensure there are no threads from the previous wave writing to shared memory being utilized for the current wave. - synchronize(); - cst_callbacks.begin(); - if (cst_callbacks.begin_sync_needed()) { - synchronize(); - } - - // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) - // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of - // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. - // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. - // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. - // Then, the next accumulation stage for buffer 1 can start. - [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; - static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); - - // For each epilogue subtile within the CTA tile - constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<4>(tTR_tAcc)); - constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<3>(tTR_tAcc)); - #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) - for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { - #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) - for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { - int epi_m = iter_m, epi_n = iter_n; - - bool is_last_iteration = iter_m == size<3>(tTR_tAcc)-1 && iter_n == size<4>(tTR_tAcc)-1; - bool do_acc_release = is_last_iteration; - - // Reverse subtile order for tmem reuse if necessary - if constexpr (ReuseTmem) { - if (reverse_epi_n) { - epi_n = size<4>(tTR_tAcc) - 1 - iter_n; - } - do_acc_release = iter_m == size<3>(tTR_tAcc)-1 && iter_n == 0; - } - - Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n); - Tensor tTR_pCD_mn = cute::lazy::transform(tTR_cCD_mn, [&] (auto const& c) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(c, problem_shape_mnl); }); - cst_callbacks.begin_loop(epi_m, epi_n); - - if constexpr (not cute::is_void_v) { - if (is_C_load_needed) { - using CVecType = uint_bit_t>; - - if constexpr (!is_same_v) { - Tensor tTR_gC_frg = recast(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); - Tensor tTR_rC_frg = recast(coalesce(tCrC)); - Tensor tTR_pC_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclC.compose(Int{}))); - copy_if(tTR_pC_frg, tTR_gC_frg, tTR_rC_frg); - } - else { - auto tiled_g2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - auto thr_g2r = tiled_g2r.get_slice(threadIdx.x); - Tensor c_src = thr_g2r.retile_S(tTR_gC(_,_,_,epi_m,epi_n)); - Tensor c_dst = thr_g2r.retile_D(tCrC); - Tensor c_prd = thr_g2r.retile_D(tTR_pCD_mn); - copy_if(tiled_g2r, c_prd, c_src, c_dst); - } - } - } - - // Copy accumulator tile from tmem to register - // The current tile in tmem - Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); - - Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); - - copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); - - // After the last tmem load, signal that tmem buffer is consumed and empty - if (do_acc_release) { - cutlass::arch::fence_view_async_tmem_load(); - acc_pipeline.consumer_release(acc_pipe_consumer_state); - ++acc_pipe_consumer_state; - } - - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tTR_rAcc_frg); ++epi_v) { - tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); - } - - Tensor reduction_buffer = make_tensor( - raw_pointer_cast(make_smem_ptr(smem_buffer_ptr)), make_layout(Shape>{})); - - cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rAcc /*not used*/); - - cst_callbacks.end_loop(epi_m, epi_n); - - using VecType = uint_bit_t>; - if constexpr (!is_same_v) { - Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); - Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); - Tensor tTR_pD_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclD.compose(Int{}))); - copy_if(tTR_pD_frg, tTR_rD_frg, tTR_gD_frg); - } - else { - auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - auto thr_r2g = tiled_r2g.get_slice(threadIdx.x); - Tensor src = thr_r2g.retile_S(tTR_rD); - Tensor dst = thr_r2g.retile_D(tTR_gD(_,_,_,epi_m,epi_n)); - Tensor prd = thr_r2g.retile_D(tTR_pCD_mn); - copy_if(tiled_r2g, prd, src, dst); - } - - } // for epi_m - } // for epi_n - - cst_callbacks.end(); - }; - - // - // BEGIN EPILOGUE - // - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - epi_loop_fn(cst_callbacks); - return cute::make_tuple(acc_pipe_consumer_state); - } - -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// For sm100 kernels requiring warp specialized epilogues -template < - class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class ThreadEpilogueOp_, - class CopyOpT2R_, - class AlignmentC_, - class AlignmentD_ -> -class CollectiveEpilogue< - Sm100NoSmemWarpSpecialized, - EpilogueTile_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - ThreadEpilogueOp_, - CopyOpT2R_, - AlignmentC_, - AlignmentD_ -> : public detail::Sm100TmaWarpSpecializedAdapter> -{ -public: - // ctor inheritance - using detail::Sm100TmaWarpSpecializedAdapter>::Sm100TmaWarpSpecializedAdapter; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp deleted file mode 100644 index 412a4b7b747b60ebedfa26eec95a692b4d9adaf4..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp +++ /dev/null @@ -1,1299 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/arch/barrier.h" -#include "cutlass/conv/convnd_problem_shape.hpp" -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/epilogue/thread/scale_type.h" -#include "cutlass/epilogue/fusion/callbacks.hpp" -#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" -#include "cutlass/detail/layout.hpp" -#include "cutlass/detail/helper_macros.hpp" -#include "cutlass/trace.h" - -#include "cutlass/conv/detail.hpp" -#include "cute/tensor.hpp" -#include "cutlass/cuda_host_adapter.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_, - class CtaTileShape_, // (CTA_M,CTA_N,CTA_K, optional: Tile_L) - class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class FusionCallbacks_, - class CopyOpT2R_, - class CopyOpG2S_, - class SmemLayoutAtomC_, - class CopyOpS2R_, - class CopyOpS2G_, - class SmemLayoutAtomD_, - class CopyOpR2S_, - class CopyOpR2R_ -> -class CollectiveEpilogue< - Sm100TmaWarpSpecialized, - CtaTileShape_, - EpilogueTile_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - FusionCallbacks_, - CopyOpT2R_, - CopyOpG2S_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpS2G_, - SmemLayoutAtomD_, - CopyOpR2S_, - CopyOpR2R_ -> { -public: - // - // Type Aliases - // - using DispatchPolicy = Sm100TmaWarpSpecialized; - using CtaTileShape = CtaTileShape_; - using EpilogueTile = EpilogueTile_; - using FusionCallbacks = FusionCallbacks_; - using ElementC = ElementC_; - using StrideC = StrideC_; - using ElementD = ElementD_; - using StrideD = StrideD_; - using CopyOpT2R = CopyOpT2R_; - using CopyOpG2S = CopyOpG2S_; - using SmemLayoutAtomC = SmemLayoutAtomC_; - using CopyOpS2R = CopyOpS2R_; - using CopyOpS2G = CopyOpS2G_; - using SmemLayoutAtomD = SmemLayoutAtomD_; - using CopyOpR2S = CopyOpR2S_; - using CopyOpR2R = CopyOpR2R_; - - using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; - using GmemTiledCopyC = CopyOpG2S; - using GmemTiledCopyD = CopyOpS2G; - - constexpr static int ThreadCount = 128; - - static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); - static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); - -private: - using GmemElementD = ElementD; - using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages - using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; - using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - static_assert(StagesC >= 1, "StagesC must be >= 1"); - static_assert(StagesD >= 1, "StagesD must be >= 1"); - - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool is_source_supported = not cute::is_void_v; - - constexpr static bool is_m_major_C = detail::is_m_major(); - constexpr static bool is_m_major_D = detail::is_m_major(); - - constexpr static bool is_im2col_C = cute::is_same_v; - constexpr static bool is_im2col_D = cute::is_same_v; - - using SmemLayoutStageC = decltype(tile_to_shape(SmemLayoutAtomC{}, product_each(shape(EpilogueTile{})), - cute::conditional_t, Step<_1,_2>>{} )); - using SmemLayoutStageD = decltype(tile_to_shape(SmemLayoutAtomD{}, product_each(shape(EpilogueTile{})), - cute::conditional_t, Step<_1,_2>>{} )); - - constexpr static int StageCBits = cosize_v * sizeof_bits_v; - constexpr static int StageDBits = cosize_v * sizeof_bits_v; - constexpr static int MaxStageBits = cute::max(StageCBits, StageDBits); - constexpr static int StrideStageC = (ReuseSmemC ? MaxStageBits : StageCBits) / sizeof_bits_v; - constexpr static int StrideStageD = (ReuseSmemC ? MaxStageBits : StageDBits) / sizeof_bits_v; - - using SmemLayoutC = decltype(cute::append<3>(SmemLayoutStageC{}, Layout, Int>{})); - using SmemLayoutD = decltype(cute::append<3>(SmemLayoutStageD{}, Layout, Int>{})); - - constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC - && MaxStageBits % sizeof_bits_v == 0 - && MaxStageBits % sizeof_bits_v == 0; - static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); - - constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); - constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); - constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); - - // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. - constexpr static bool UnrollEpiLoop = - not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; - // TMA store delay only benefits with loop unrolling - constexpr static bool DelayTmaStore = DelayTmaStore_ and UnrollEpiLoop; - - struct CollectiveStorageWithC { - alignas(SmemAlignmentC) ArrayEngine> smem_C; - alignas(SmemAlignmentD) ArrayEngine> smem_D; - }; - - union CollectiveStorageWithoutC { - cute::array smem_C; - alignas(SmemAlignmentD) ArrayEngine> smem_D; - }; - - union CollectiveStorageReuseC { - alignas(MaxSmemAlignment) ArrayEngine> smem_C; - alignas(MaxSmemAlignment) ArrayEngine> smem_D; - }; - -public: - // TMA pipeline for loading C - using LoadPipeline = cutlass::PipelineTransactionAsync; - using LoadPipelineState = cutlass::PipelineState; - constexpr static uint32_t TmaTransactionBytes = StageCBits / 8; - - // TMA pipeline for storing D - using StorePipeline = cute::conditional_t, - cutlass::PipelineTmaStore>; - using StorePipelineState = cutlass::PipelineState; - - struct SharedStorage { - struct TensorStorage { - using CollectiveStorage = cute::conditional_t>; - CollectiveStorage collective; - - using FusionStorage = typename FusionCallbacks::SharedStorage; - FusionStorage thread; - } tensors; - - using PipelineStorage = typename LoadPipeline::SharedStorage; - PipelineStorage pipeline; - }; - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - // Planar complex kernels have two accumulator copies for the real and imaginary tensors. - constexpr static int NumAccumulatorMtxs = 1; - - // Host side epilogue arguments - struct Arguments { - typename FusionCallbacks::Arguments thread{}; - ElementC const* ptr_C = nullptr; - StrideC dC{}; - ElementD* ptr_D = nullptr; - StrideD dD{}; - }; - -private: - static constexpr auto - get_tma_epi_tile() { - return cute::transform_apply(EpilogueTile{}, seq<0,1>{}, - [] (auto epi_tiler, auto mode) { - auto cta_tiler_shape = get(CtaTileShape{}); - // Use a dynamic stride to prevent mode coalescing - auto cta_tiler_stride = repeat_like(cta_tiler_shape, 0); - auto cta_tiler = make_layout(cta_tiler_shape, cta_tiler_stride); - // This is a multimodal CTA tiler, transform before returning - if constexpr (depth(cta_tiler) > 0) { - // This is an implicit multimodal tiler, match profile and return - if constexpr (tuple_size_v == 1) { - return make_tile(epi_tiler); - } - // This is an explicit multimodal tiler, compose out epi tiler - else { - return shape(composition(cta_tiler, epi_tiler)); - } - } - // This is a flat CTA tiler, no need for transformation - else { - return epi_tiler; - } - }, - [] (auto... epi_tilers) { - return make_tile(epi_tilers...); - } - ); - } - - using TmaEpilogueTile = decltype(get_tma_epi_tile()); - - template - static constexpr auto - get_tma_load_c(ProblemShapeMNL const& problem_shape_mnl, Arguments const& args) { - Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), - make_layout(problem_shape_mnl, append<3>(args.dC, _0{}))); - return make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutStageC{}, TmaEpilogueTile{}, _1{}); - } - - template - static constexpr auto - get_tma_store_d(ProblemShapeMNL const& problem_shape_mnl, Arguments const& args) { - Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), - make_layout(problem_shape_mnl, append<3>(args.dD, _0{}))); - return make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutStageD{}, TmaEpilogueTile{}, _1{}); - } - -public: - // Device side epilogue params - struct Params { - using TMA_C = decltype(get_tma_load_c (repeat_like(append<3>(StrideC{},_1{}), int32_t(0)), Arguments{})); - using TMA_D = decltype(get_tma_store_d(repeat_like(append<3>(StrideD{},_1{}), int32_t(0)), Arguments{})); - - typename FusionCallbacks::Params thread{}; - TMA_C tma_load_c; - TMA_D tma_store_d; - }; - - // - // Gemm Host Functions - // - - template - static constexpr Params - to_underlying_arguments( - ProblemShape const& problem_shape, - Arguments const& args, - [[maybe_unused]] void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_mnl = select<0,1,3>(append<4>(problem_shape, 1)); - typename Params::TMA_C tma_load_c{}; - if constexpr (is_source_supported) { - tma_load_c = get_tma_load_c(problem_shape_mnl, args); - } - - typename Params::TMA_D tma_store_d = get_tma_store_d(problem_shape_mnl, args); - - return { - FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), - tma_load_c, - tma_store_d - }; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return FusionCallbacks::get_workspace_size(problem_shape, args.thread); - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); - } - - template - static bool - can_implement( - ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - constexpr int tma_alignment_bits_d = cutlass::detail::get_output_alignment_bits(); - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - auto shape = cute::make_shape(M,N,L); - - bool implementable = true; - constexpr int min_tma_aligned_elements_D = tma_alignment_bits_d / cutlass::sizeof_bits::value; - if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm - implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideD{})); - } - else { - implementable = implementable && cutlass::detail::check_alignment(shape, StrideD{}); - } - - if constexpr (is_source_supported) { - constexpr int tma_alignment_bits_c = cutlass::detail::get_output_alignment_bits(); - constexpr int min_tma_aligned_elements_C = tma_alignment_bits_c / cutlass::sizeof_bits::value; - if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm - implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideC{})); - } - else { - implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); - } - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - } - - bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); - - if (!fusion_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); - } - - return implementable && fusion_implementable; - } - - // - // Conv Host Functions - // - - template - static constexpr Params - to_underlying_arguments(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return to_underlying_arguments(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args, workspace); - } - - template - static size_t - get_workspace_size(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args) { - return get_workspace_size(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args); - } - - template - static cutlass::Status - initialize_workspace(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args, - void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { - return initialize_workspace(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args, workspace, stream, cuda_adapter); - } - - template - static bool - can_implement(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args) { - return can_implement(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args); - } - - // - // Static Device Functions - // - - template - CUTLASS_DEVICE - static constexpr int - get_load_pipe_increment(CtaTileMNK const& cta_tile_mnk) { - // Compute number of epilogue subtiles - return size<1>(zipped_divide(make_layout(take<0,2>(cta_tile_mnk)), EpilogueTile{})); - } - - template - CUTLASS_DEVICE - static constexpr int - get_store_pipe_increment(CtaTileMNK const& cta_tile_mnk) { - return get_load_pipe_increment(cta_tile_mnk); - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE static void - prefetch_tma_descriptors(Params const& epilogue_params) { - cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); - cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); - } - - // - // Constructor and Data Members - // - CUTLASS_DEVICE - CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) - : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} - -private: - Params const& params; - FusionCallbacks fusion_callbacks; - - // - // Non-static Device Functions - // -public: - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return fusion_callbacks.is_producer_load_needed(); - } - - template< - bool ReuseTmem = false, - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class MmaTileMNK, - class TiledMma - > - CUTLASS_DEVICE auto - load( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - TensorStorage& shared_tensors, - bool reverse_epi_n = false) { - using namespace cute; - - int lane_idx = canonical_lane_idx(); - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; - - // The tma tensor C under im2col mode only has two modes (M, N) which - // should be local tiled with only (m_coord, n_coord). - auto coord_shape = - conditional_return(make_coord(m_coord, n_coord), make_coord(m_coord, n_coord, l_coord)); - - // Represent the full source tensor, slice to get the tile this CTA is currently responsible for - Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor mC = coalesce(mC_mn, take<0,2>(cta_tile_mnk)); - Tensor gC = local_tile(mC, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) - - // Apply epilogue subtile, get matching smem tensor - auto ptr_sC = shared_tensors.collective.smem_C.begin(); - Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) - - // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) - ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); - Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) - Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (TMA,TMA_M,TMA_N,PIPE_C) - - // Get the fusion callbacks for the producer load warp - auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - tiled_mma, - EpilogueTile{}, - lane_idx - }; - auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - // Predication for TMA load (one thread issues TMA load) - bool issue_tma_load = cute::elect_one_sync(); - - // Pre-loop fusion callback entry point - pld_callbacks.begin(); - - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<3>(gC_epi); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<2>(gC_epi); ++iter_m) { - int epi_m = iter_m, epi_n = iter_n; - if constexpr (ReuseTmem) { - if (reverse_epi_n) { - epi_n = size<3>(gC_epi) - 1 - iter_n; - } - } - // Acquire the lock for this stage - constexpr uint16_t mcast_mask = 0; - uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); - load_pipeline.producer_acquire(load_pipe_producer_state); - - // Execute the TMA load for C if needed - if (issue_tma_load && is_C_load_needed) { - copy(params.tma_load_c.with(*tma_barrier, mcast_mask), - bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); - load_pipeline.producer_expect_transaction(load_pipe_producer_state); - } - - // Loop fusion callback entry point - pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); - - // Commit TMA loads for this stage and release the lock - load_pipeline.producer_commit(load_pipe_producer_state); - ++load_pipe_producer_state; - } - } - - // Post-loop fusion callback entry point - pld_callbacks.end(); - - return load_pipe_producer_state; - } - - CUTLASS_DEVICE void - load_tail( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state, - [[maybe_unused]] StorePipeline store_pipeline, - [[maybe_unused]] StorePipelineState store_pipe_producer_state) { - load_pipeline.producer_tail(load_pipe_producer_state); - } - - template< - bool ReuseTmem = false, - class AccumulatorPipeline, - class AccumulatorPipelineState, - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class MmaTileMNK, - class TiledMma, - class AccEngine, - class AccLayout - > - CUTLASS_DEVICE auto - store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - AccumulatorPipeline acc_pipeline, - AccumulatorPipelineState acc_pipe_consumer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor accumulators, - TensorStorage& shared_tensors - ) { - using namespace cute; - using ElementAccumulator = typename AccEngine::value_type; - using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; - using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; - - static_assert(is_tmem::value, "Accumulator must be TMEM resident."); - static_assert(rank(accumulators) == 3, "Accumulators must be MMA-partitioned: [MMA, MMA_M, MMA_N]"); - static_assert(size<1>(accumulators) == 1 && size<2>(accumulators) == 1, "TiledMMA must match partitioned ShapeMN"); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); - - // Indexing variables - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; - int thread_idx = threadIdx.x % ThreadCount; - int warp_idx = thread_idx / NumThreadsPerWarp; - [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; - - // The tma tensor D under im2col mode only has two modes (M, N) which - // should be local tiled with only (m_coord, n_coord). - auto coord_shape = - conditional_return(make_coord(m_coord, n_coord), make_coord(m_coord, n_coord, l_coord)); - - // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); - Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) - - Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) - - // Apply epilogue subtiling - Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Construct the corresponding pipelined smem tensors - auto ptr_sC = shared_tensors.collective.smem_C.begin(); - auto ptr_sD = shared_tensors.collective.smem_D.begin(); - Tensor sC_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) - Tensor sD_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) - - // (t)hread-partition for (t)mem to (r)egister copy (tTR_) - TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); - ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) - - // Allocate D and accumulator registers - // Does directly store the visitor into smem. - constexpr bool IsDirectR2S = cute::is_same_v>; - using RegisterElementD = cute::conditional_t; - Tensor tTR_rAcc = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) - Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) - - // Vectorized fragment view - constexpr int FragmentSize = DispatchPolicy::FragmentSize; - Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); // (EPI_V) - Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) - CUTE_STATIC_ASSERT(size(tTR_rAcc) % DispatchPolicy::FragmentSize == 0, "Fragment size does not vectorize properly"); - - // (t)hread-partition for (s)mem to (r)egister copy (tSR_) - TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); - Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) - Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) - - // Allocate C registers - // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type - // to eliminate some redundant pack+unpack instruction sequences for sub-word types - constexpr bool IsDirectS2R = cute::is_same_v> - && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; - using RegisterElementC = cute::conditional_t; - Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) - Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) - - // (t)hread-partition for (r)egister to (r)egister copy (tRR_) - TiledCopy tiled_r2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); - Tensor tRR_rD_src = thread_r2r.retile_S(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) - Tensor tRR_rD_dst = thread_r2r.retile_D(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) - - // (t)hread-partition for (r)egister to (s)mem copy (tRS_) - TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_r2r); - ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); - Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) - Tensor tRS_rD = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (!IsDirectR2S) { - return make_tensor(shape(tRS_sD(_,_,_,_0{}))); - } - else{ - return thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) - } - }(); - - Tensor tRR_rD_dst_frg = recast>(coalesce(tRR_rD_dst)); - Tensor tRS_rD_frg = recast>(coalesce(tRS_rD)); - - // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) - ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); - Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) - Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - - // OOB predication for tile quantization "residue" - // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - // Relative coordinate tensors (static) - Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) - Tensor tTR_cD = make_coord_tensor(tTR_cD_mn.layout()); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate - auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) - auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) - - // Arguments for the fusion callbacks for the consumer store warps - constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - tiled_mma, - EpilogueTile{}, - tiled_t2r, - cD, - residue_cD, - tTR_cD, - residue_tTR_cD, - tTR_rC, - thread_idx - }; - - // Thread synchronizer for previously issued waits or fences - // to ensure visibility of smem reads/writes to threads or TMA unit - auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // Predication for sub-128 thread T2R tiled copy - Layout tmem_warp_layout = typename decltype(make_tmem_warp_partitioner(tAcc_epi(_,_,0,0)))::TiledLayout_TV{}; - constexpr bool predicate_tmem_load = size(tmem_warp_layout) != cosize(tmem_warp_layout); - bool issue_tmem_load = true; - - // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) - // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of - // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. - // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. - // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. - // Then, the next accumulation stage for buffer 1 can start. - [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; - static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); - - // Predication for TMA store (one warp issues TMA store) - bool issue_tma_store = warp_idx == 0; - - // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. - // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can - // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. - // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. - // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. - LoadPipelineState load_wait_state = load_pipe_consumer_state; - if constexpr (ReuseSmemC) { - load_wait_state = store_pipe_producer_state; - load_wait_state.phase_ ^= 1; - } - - // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions - // Sync requirements of smem reuse may preclude this optimization - // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD - [[maybe_unused]] int epi_m_prev = 0; - [[maybe_unused]] int epi_n_prev = 0; - static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); - - // The Epilogue Loop - auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - // The TMA store sequence for one epilogue loop iteration - auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { - // Write the tile from smem to gmem with TMA - cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - synchronize(); // ensure all threads have issued their async fence - if (issue_tma_store) { - copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); - } - - // Post async fence, pre TMA commit callback entry point - cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); - - // Commit the TMA stores for this stage - if (issue_tma_store) { - store_pipeline.producer_commit(store_pipe_producer_state); - } - ++store_pipe_producer_state; - - // Wait for the next smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); - } - synchronize(); - - if constexpr (ReuseSmemC) { - // producer_acquire returns when at most StagesD-1 committed stores are pending - bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; - // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits - if (store_finished) { - if (is_producer_load_needed) { - load_pipeline.consumer_release(load_pipe_consumer_state); - } - ++load_pipe_consumer_state; - } - } - }; // tma_store_fn - - cst_callbacks.begin(); - if (cst_callbacks.begin_sync_needed()) { - synchronize(); - } - - // Begin the wait for the producer load results - ConsumerToken load_wait_token{BarrierStatus::WaitDone}; - if (is_producer_load_needed) { - load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); - } - // Begin the wait for the accumulator results - ConsumerToken acc_wait_token = acc_pipeline.consumer_try_wait(acc_pipe_consumer_state); - - // For each epilogue subtile within the CTA tile - constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); - constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); - #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) - for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { - #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) - for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { - int epi_m = iter_m, epi_n = iter_n; - bool is_first_iteration = iter_m == 0 && iter_n == 0; - bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; - bool do_acc_release = is_last_iteration; - - // Reverse subtile order for tmem reuse if necessary - if constexpr (ReuseTmem) { - if (reverse_epi_n) { - epi_n = size<3>(gD_epi) - 1 - iter_n; - } - do_acc_release = iter_m == size<2>(gD_epi)-1 && iter_n == 0; - } - - cst_callbacks.begin_loop(epi_m, epi_n); - - if (is_producer_load_needed) { - // Wait for the producer load to fill smem - load_pipeline.consumer_wait(load_wait_state, load_wait_token); - - if (is_C_load_needed) { - // Copy source tile from smem to register - copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); - // Ensure smem loads are complete before reusing smem for mixed types/layouts - if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { - synchronize(); - } - } - } - - // First loop fusion callback entry point - cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); - - if (is_producer_load_needed) { - // Let producer load warp know smem buffers are consumed and empty - if constexpr (not ReuseSmemC) { - cutlass::arch::fence_view_async_shared(); - load_pipeline.consumer_release(load_pipe_consumer_state); - ++load_pipe_consumer_state; - } - ++load_wait_state; - } - - if (is_first_iteration) { - // Wait for mma warp to fill tmem buffer with accumulator results - acc_pipeline.consumer_wait(acc_pipe_consumer_state, acc_wait_token); - } - - // The current tile in tmem - Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); - - // Compute tmem load predication if necessary - if constexpr (predicate_tmem_load) { - // Issue tmem load if this tile's tmem subpartition is accessible by this warp - int subpart_idx = (tTR_tAcc_mn.data().dp_ / 32) % 4; - issue_tmem_load = warp_idx == subpart_idx; - } - bool issue_smem_store = issue_tmem_load; - - // Copy accumulator tile from tmem to register - if (issue_tmem_load) { - copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); - } - - // After the last tmem load, signal that tmem buffer is consumed and empty - if (do_acc_release) { - cutlass::arch::fence_view_async_tmem_load(); - acc_pipeline.consumer_release(acc_pipe_consumer_state); - ++acc_pipe_consumer_state; - } - - // Vectorized fragment loop with visitor callback entry point - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { - tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); - } - - // The latest we can delay the TMA store is right before the smem store of the next iteration - // since the current TMA store needs to be committed before we can acquire the next smem buffer - if constexpr (DelayTmaStore) { - // Issue TMA stores for the previous subtile - if (not is_first_iteration) { - tma_store_fn(epi_m_prev, epi_n_prev); - } - epi_m_prev = epi_m; - epi_n_prev = epi_n; - } - - if constexpr (!IsDirectR2S) { - // At present, only FP4 col output with scalefactor generation fusion would go into these branch - copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); - } - tRS_rD_frg(_0{}) = cutlass::NumericArrayConverter{}(tRR_rD_dst_frg(_0{})); - - // Smem reduction callback entry point using current store buffer for workspace - Tensor reduction_buffer = make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), - make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); - cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); - - // Copy output tile from register to smem - if (issue_smem_store) { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - } - - // Post reduction, pre TMA store callback entry point - cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); - - if constexpr (not DelayTmaStore) { - // Issue TMA stores for this subtile - tma_store_fn(epi_m, epi_n); - } - - cst_callbacks.end_loop(epi_m, epi_n); - - if (is_producer_load_needed) { - // Begin the wait for the next subtile producer load - load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); - } - } // for epi_m - } // for epi_n - - if constexpr (DelayTmaStore) { - // Issue TMA stores for the last subtile - tma_store_fn(epi_m_prev, epi_n_prev); - } - - cst_callbacks.end(); - }; // epi_loop_fn - - // - // BEGIN EPILOGUE - // - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - epi_loop_fn(cst_callbacks); - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); - } - - // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. - // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator - // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the - // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. - template< - class ProblemShapeMNKL, - class CtaTileMNK, - class CtaCoordMNKL, - class MmaTileMNK, - class TiledMma, - class AccEngine, - class AccLayout, - class TiledCopyT2R - > - CUTLASS_DEVICE auto - store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor& tTR_rAcc, // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - TensorStorage& shared_tensors, - TiledCopyT2R tiled_t2r - ) { - using namespace cute; - using ElementAccumulator = typename AccEngine::value_type; - using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; - using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; - - static_assert(is_rmem::value, "Accumulator must be Register resident."); - static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); - - // Indexing variables - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; - int thread_idx = threadIdx.x % ThreadCount; - int warp_idx = thread_idx / NumThreadsPerWarp; - [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; - - // The tma tensor D under im2col mode only has two modes (M, N) which - // should be local tiled with only (m_coord, n_coord). - auto coord_shape = - conditional_return(make_coord(m_coord, n_coord), make_coord(m_coord, n_coord, l_coord)); - - // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); - Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) - - // Apply epilogue subtiling - Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Construct the corresponding pipelined smem tensors - auto ptr_sC = shared_tensors.collective.smem_C.begin(); - auto ptr_sD = shared_tensors.collective.smem_D.begin(); - Tensor sC_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) - Tensor sD_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) - - // (t)hread-partition for (t)mem to (r)egister copy (tTR_) - ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) - - // Allocate D and accumulator registers - Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) - - // Vectorized fragment view - constexpr int FragmentSize = DispatchPolicy::FragmentSize; - Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) - - // (t)hread-partition for (s)mem to (r)egister copy (tSR_) - TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); - Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) - Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) - - // Allocate C registers - // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type - // to eliminate some redundant pack+unpack instruction sequences for sub-word types - constexpr bool IsDirectS2R = cute::is_same_v> - && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; - using RegisterElementC = cute::conditional_t; - Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) - Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) - - // (t)hread-partition for (r)egister to (s)mem copy (tRS_) - TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); - Tensor tRS_rD = thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) - Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) - - // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) - ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); - Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) - Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - - // OOB predication for tile quantization "residue" - // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - // Relative coordinate tensors (static) - Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) - Tensor tTR_cD = make_coord_tensor(tTR_cD_mn.layout()); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) - // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate - auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) - auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) - - // Get the fusion callbacks for the consumer store warps - constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, - cta_tile_mnk, - cta_coord_mnkl, - tiled_mma, - EpilogueTile{}, - tiled_t2r, - cD, - residue_cD, - tTR_cD, - residue_tTR_cD, - tTR_rC, - thread_idx - }; - - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - // Thread synchronizer for previously issued waits or fences - // to ensure visibility of smem reads/writes to threads or TMA unit - auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // Predication for TMA store (one warp issues TMA store) - bool issue_tma_store = warp_idx == 0; - - // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. - // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can - // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. - // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. - // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. - LoadPipelineState load_wait_state = load_pipe_consumer_state; - if constexpr (ReuseSmemC) { - load_wait_state = store_pipe_producer_state; - load_wait_state.phase_ ^= 1; - } - - // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions - // Sync requirements of smem reuse may preclude this optimization - // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD - int epi_m_prev = 0, epi_n_prev = 0; - static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); - - // The TMA store sequence for one subtile iteration - auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { - // Write the tile from smem to gmem with TMA - cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - synchronize(); // ensure all threads have issued their async fence - if (issue_tma_store) { - copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); - } - - // Post async fence, pre TMA commit callback entry point - cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); - - // Commit the TMA stores for this stage - if (issue_tma_store) { - store_pipeline.producer_commit(store_pipe_producer_state); - } - ++store_pipe_producer_state; - - // Wait for the next smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); - } - synchronize(); - - if constexpr (ReuseSmemC) { - // producer_acquire returns when at most StagesD-1 committed stores are pending - bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; - // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits - if (store_finished) { - if (is_producer_load_needed) { - load_pipeline.consumer_release(load_pipe_consumer_state); - } - ++load_pipe_consumer_state; - } - } - }; - - // - // BEGIN EPILOGUE - // - - cst_callbacks.begin(); - if (cst_callbacks.begin_sync_needed()) { - synchronize(); - } - - // Begin the wait for the producer load results - ConsumerToken load_wait_token{BarrierStatus::WaitDone}; - if (is_producer_load_needed) { - load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); - } - - // For each epilogue subtile within the CTA tile - constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); - constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); - #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) - for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { - #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) - for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { - int epi_m = iter_m, epi_n = iter_n; - bool is_first_iteration = iter_m == 0 && iter_n == 0; - bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; - - cst_callbacks.begin_loop(epi_m, epi_n); - - if (is_producer_load_needed) { - // Wait for the producer load to fill smem - load_pipeline.consumer_wait(load_wait_state, load_wait_token); - - if (is_C_load_needed) { - // Copy source tile from smem to register - copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); - // Ensure smem loads are complete before reusing smem for mixed types/layouts - if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { - synchronize(); - } - } - } - - // First loop fusion callback entry point - cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); - - if (is_producer_load_needed) { - // Let producer load warp know smem buffers are consumed and empty - if constexpr (not ReuseSmemC) { - cutlass::arch::fence_view_async_shared(); - load_pipeline.consumer_release(load_pipe_consumer_state); - ++load_pipe_consumer_state; - } - ++load_wait_state; - } - - Tensor tTR_rAcc_epi_tile = tTR_rAcc(_,_,_,epi_m,epi_n); - Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc_epi_tile)); // (EPI_V) - - // Vectorized fragment loop with visitor callback entry point - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { - tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); - } - - // The latest we can delay the TMA store is right before the smem store of the next iteration - // since the current TMA store needs to be committed before we can acquire the next smem buffer - if constexpr (DelayTmaStore) { - // Issue TMA stores for the previous subtile - if (not is_first_iteration) { - tma_store_fn(epi_m_prev, epi_n_prev); - } - epi_m_prev = epi_m; - epi_n_prev = epi_n; - } - - // Smem reduction callback entry point using current store buffer for workspace - Tensor reduction_buffer = make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), - make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); - cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rD_frg); - - // Copy output tile from register to smem - bool issue_smem_store = true; - if (issue_smem_store) { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - } - - // Post reduction, pre TMA store callback entry point - cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); - - if constexpr (not DelayTmaStore) { - // Issue TMA stores for this subtile - tma_store_fn(epi_m, epi_n); - } - - cst_callbacks.end_loop(epi_m, epi_n); - - if (is_producer_load_needed) { - // Begin the wait for the next subtile producer load - load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); - } - } // for epi_m - } // for epi_n - - if constexpr (DelayTmaStore) { - // Issue TMA stores for the last subtile - tma_store_fn(epi_m_prev, epi_n_prev); - } - - cst_callbacks.end(); - - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); - } - - template - CUTLASS_DEVICE void - store_tail( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - CtaTileMNK cta_tile_mnk) { - if constexpr (ReuseSmemC) { - if (fusion_callbacks.is_producer_load_needed()) { - // wait for all TMA stores to complete - store_pipeline.producer_tail(store_pipe_producer_state); - - // Issue releases on up to StagesD-1 previously issued TMA stores - constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(cta_tile_mnk)); - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < release_stages; ++stage) { - load_pipeline.consumer_release(load_pipe_consumer_state); - ++load_pipe_consumer_state; - } - } - } - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp deleted file mode 100644 index c2b8d84dc92fb8b1a823135b2fdc556bce9dbebc..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp +++ /dev/null @@ -1,549 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - class StrideC, - class StrideD, - class ThreadEpilogueOp, - class SmemLayout, - class CopyAtomR2S, - class TiledCopyS2R, - class CopyAtomR2G, - class EpilogueScheduleType = EpilogueSimtVectorized, - class Enable = void -> -class Epilogue { - static_assert(cute::is_same_v || - cute::is_same_v, - "Could not find an epilogue specialization."); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Epilogue Vectorized -/// Applies an element wise operation to all elements within the fragment -/// and writes it out to destination storage. -/// -/// Ways to generalize this: -/// - CTA tile shape -/// - vectorization requirements (GMEM) -/// - vectoriz(able) transform() -/// -template < - class StrideC_, - class StrideD_, - class ThreadEpilogueOp_, - class SmemLayout_, - class CopyAtomR2S_, - class TiledCopyS2R_, - class CopyAtomR2G_, - class EpilogueScheduleType_ -> -class Epilogue< - StrideC_, - StrideD_, - ThreadEpilogueOp_, - SmemLayout_, - CopyAtomR2S_, - TiledCopyS2R_, - CopyAtomR2G_, - EpilogueScheduleType_, - cute::enable_if_t< - cute::is_same_v - > - > { -public: - // - // Type Aliases - // - // derived types of output thread level operator - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementScalar = ElementCompute; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementC = typename ThreadEpilogueOp::ElementC; - using StrideC = StrideC_; - using ElementD = typename ThreadEpilogueOp::ElementD; - using StrideD = StrideD_; - using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; - using SmemLayout = SmemLayout_; - using CopyAtomR2S = CopyAtomR2S_; - using TiledCopyS2R = TiledCopyS2R_; - using CopyAtomR2G = CopyAtomR2G_; - - using GmemTiledCopyC = void; - using GmemTiledCopyD = CopyAtomR2G; - - static constexpr bool IsEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; - using StrideBias = cute::conditional_t(), Stride<_1,_0,int64_t>, Stride<_0,_1,int64_t>>; - - static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - - struct SharedStorage - { - cute::array_aligned> smem_epilogue; - }; - - static constexpr bool IsActHasArgs = detail::IsThreadEpilogueOpWithElementwiseArguments::value; - - // Host side epilogue arguments - template - struct ThreadEpilogueOpArguments { - ElementScalar alpha{0}; - ElementScalar beta{0}; - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias{}; - }; - - template - struct ThreadEpilogueOpArguments< - ThreadEpiOp, - cute::enable_if_t::value>> { - ElementScalar alpha{0}; - ElementScalar beta{0}; - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias{}; - typename ThreadEpiOp::ElementwiseArguments activation{}; - }; - - struct Arguments { - ThreadEpilogueOpArguments thread{}; - using StrideBias = decltype(thread.dBias); - ElementC const* ptr_C = nullptr; - StrideC dC{}; - ElementD* ptr_D = nullptr; - StrideD dD{}; - }; - - // Device side epilogue params - template - struct ParamsType { - typename ThreadEpiOp::Params thread{}; - ElementC const* ptr_C = nullptr; - StrideC dC{}; - ElementD* ptr_D = nullptr; - StrideD dD{}; - ElementBias const* ptr_Bias = nullptr; - StrideBias dBias{}; - }; - - template - struct ParamsType< - ThreadEpiOp, - cute::enable_if_t::value>> { - typename ThreadEpiOp::Params thread{}; - typename ThreadEpiOp::ElementwiseArguments activation{}; - ElementC const* ptr_C = nullptr; - StrideC dC{}; - ElementD* ptr_D = nullptr; - StrideD dD{}; - ElementBias const* ptr_Bias = nullptr; - StrideBias dBias{}; - }; - - using Params = ParamsType; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments( - [[maybe_unused]] ProblemShape const& _, - Arguments const& args, - [[maybe_unused]] void* workspace) { - typename ThreadEpilogueOp::Params thread_op_args; - thread_op_args.alpha = args.thread.alpha; - thread_op_args.beta = args.thread.beta; - thread_op_args.alpha_ptr = args.thread.alpha_ptr; - thread_op_args.beta_ptr = args.thread.beta_ptr; - - if constexpr (IsActHasArgs) { - return { - thread_op_args, - args.thread.activation, - args.ptr_C, - args.dC, - args.ptr_D, - args.dD, - args.thread.bias_ptr, - args.thread.dBias - }; - } - else { - return { - thread_op_args, - args.ptr_C, - args.dC, - args.ptr_D, - args.dD, - args.thread.bias_ptr, - args.thread.dBias - }; - } - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - template - static bool - can_implement( - [[maybe_unused]] ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - return true; - } - - CUTLASS_HOST_DEVICE - Epilogue(Params const& params_) - : params(params_), epilogue_op(params_.thread) { } - - CUTLASS_DEVICE - bool - is_source_needed() { - return epilogue_op.is_source_needed(); - } - - template< - class ProblemShapeMNKL, - class BlockShapeMNK, - class BlockCoordMNKL, - class FrgEngine, class FrgLayout, - class TiledMma, - class ResidueMNK - > - CUTLASS_DEVICE void - operator()( - ProblemShapeMNKL problem_shape_mnkl, - BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, - cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) - TiledMma tiled_mma, - ResidueMNK residue_mnk, - int thread_idx, - char* smem_buf) { - using namespace cute; - using X = Underscore; - - static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - - // synchronizing function for smem reads/writes -#if CUDA_BARRIER_ENABLED - auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; -#else - auto synchronize = [] () { __syncthreads(); }; -#endif - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - // Represent the full output tensor - Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) - Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) - Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), params.dBias); // (m,n,l) - - Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gBias_mnl = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - - // Construct a tensor in SMEM that we can partition for rearranging data - SharedStorage& storage = *reinterpret_cast(smem_buf); - Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) - - // Partition sAcc to match the accumulator partitioning - auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); - auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); - Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor tRS_sAcc = thread_r2s.partition_D(sAcc); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // Tile gD and gC by the shape of SmemLayout first - auto tile = make_shape(size<0>(sAcc), size<1>(sAcc)); - Tensor gCt = flat_divide(gC, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - Tensor gDt = flat_divide(gD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - Tensor gBiast = flat_divide(gBias, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - - // Partition sAcc, gC, and gD for the output - auto tiled_s2r = TiledCopyS2R{}; - auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); - Tensor tSR_sAcc = thread_s2r.partition_S(sAcc); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tSR_gC = thread_s2r.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - Tensor tSR_gD = thread_s2r.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - Tensor tSR_gBias = thread_s2r.partition_D(gBiast); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - - // Allocate intermediate registers on the dst tensors - Tensor tSR_rAcc = make_tensor(take<0,3>(shape(tSR_gC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tSR_rC = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tSR_rD = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tSR_rBias = make_tensor_like(tSR_gBias); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - - // Repeat the D-partitioning for coordinates and predication - Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) - Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - Tensor tSR_cD = thread_s2r.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - - CUTE_STATIC_ASSERT(size<1>(tRS_rAcc) % size<3>(tSR_gC) == 0); // TILE_M divides MMA_M - CUTE_STATIC_ASSERT(size<2>(tRS_rAcc) % size<4>(tSR_gC) == 0); // TILE_N divides MMA_N - -#if 0 - if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { - print("aC : "); print(accumulators.layout()); print("\n"); - print("gC : "); print(gC.layout()); print("\n"); - print("gD : "); print(gD.layout()); print("\n"); - print("gBias : "); print(gBias.layout()); print("\n"); - print("sAcc : "); print(sAcc.layout()); print("\n"); - print("\n"); - print("tRS_sAcc : "); print(tRS_sAcc.layout()); print("\n"); - print("tRS_rAcc : "); print(tRS_rAcc.layout()); print("\n"); - print("\n"); - print("gDt : "); print(gDt.layout()); print("\n"); - print("tSR_sAcc : "); print(tSR_sAcc.layout()); print("\n"); - print("tSR_rAcc : "); print(tSR_rAcc.layout()); print("\n"); - print("\n"); - print("tSR_rC : "); print(tSR_rC.layout()); print("\n"); - print("tSR_rD : "); print(tSR_rD.layout()); print("\n"); - print("tSR_gC : "); print(tSR_gC.layout()); print("\n"); - print("tSR_gD : "); print(tSR_gD.layout()); print("\n"); - print("\n"); - print("gBiast : "); print(gBiast.layout()); print("\n"); - print("tSR_gBias : "); print(tSR_gBias.layout()); print("\n"); - print("tSR_rBias : "); print(tSR_rBias.layout()); print("\n"); - } -#endif - - if constexpr (IsEpilogueBiasSupported) { - if (params.ptr_Bias) { - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - Tensor tSR_gBias_flt = filter_zeros(tSR_gBias); - Tensor tSR_rBias_flt = filter_zeros(tSR_rBias); - Tensor tSR_cD_flt = filter_zeros(tSR_cD, tSR_gBias.stride()); - Tensor tSR_pD_flt = cute::lazy::transform(tSR_cD_flt, [&](auto const& c){ return elem_less(c, take<0,2>(residue_mnk)); }); - - // Step 0. Copy Bias from GMEM to fragment - copy_if(tSR_pD_flt, tSR_gBias_flt, tSR_rBias_flt); - } - } - - // For each tiling needed for SmemLayout to cover shape(gD) - CUTLASS_PRAGMA_UNROLL - for (int step_m = 0; step_m < size<2>(cDt); ++step_m) { - CUTLASS_PRAGMA_UNROLL - for (int step_n = 0; step_n < size<3>(cDt); ++step_n) { - // Step 1. Copy to SMEM - CUTLASS_PRAGMA_UNROLL - for (int pipe_m = 0; pipe_m < size<1>(tRS_sAcc); ++pipe_m) { - CUTLASS_PRAGMA_UNROLL - for (int pipe_n = 0; pipe_n < size<2>(tRS_sAcc); ++pipe_n) { - int mma_m = step_m * size<1>(tRS_sAcc) + pipe_m; - int mma_n = step_n * size<2>(tRS_sAcc) + pipe_n; - - copy(tiled_r2s, tRS_rAcc(_,mma_m,mma_n), tRS_sAcc(_,pipe_m,pipe_n)); - } - } - - // Step 2. Wait for SMEM writes to complete - synchronize(); - - // Step 3. Copy from SMEM into a fragment - copy(tiled_s2r, tSR_sAcc, tSR_rAcc); - - // Step 4. Wait for SMEM reads to complete - synchronize(); - - Tensor tSR_gDmn = tSR_gD(_,_,_,step_m,step_n); - Tensor tSR_cDmn = tSR_cD(_,_,_,step_m,step_n); - - if constexpr (IsEpilogueBiasSupported) { - Tensor tSR_rBiasmn = tSR_rBias(_,_,_,step_m,step_n); - - if (epilogue_op.is_source_needed()) { - // source is needed - Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); - - // Step 5. Copy C from GMEM to a fragment - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_gDmn); ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_gDmn); ++n) { - // Predication - if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rAcc); ++i) { - tSR_rC(i,m,n) = tSR_gCmn(i,m,n); - } - } - } - } - - // Step 6. Elementwise operation with conversion - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tSR_rAcc); ++i) { - if constexpr (IsActHasArgs) { - epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rC(i), tSR_rBiasmn(i), params.activation); - } else { - epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rC(i), tSR_rBiasmn(i)); - } - } - } - else { - // source is not needed, avoid load and lift compute - - // Step 5. Elementwise operation with conversion - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tSR_rAcc); ++i) { - if constexpr (IsActHasArgs) { - epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rBiasmn(i), params.activation); - } else { - epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rBiasmn(i)); - } - } - } - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_gDmn); ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_gDmn); ++n) { - // Predication - if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { - // The Last Step. Copy to GMEM - copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); - } - } - } - } else { - if (epilogue_op.is_source_needed()) { - // source is needed - Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); - - // Step 5. Copy C from GMEM to a fragment - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_gDmn); ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_gDmn); ++n) { - // Predication - if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rAcc); ++i) { - tSR_rC(i,m,n) = tSR_gCmn(i,m,n); - } - } - } - } - - // Step 6. Elementwise operation with conversion - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tSR_rAcc); ++i) { - tSR_rD(i) = epilogue_op(tSR_rAcc(i), tSR_rC(i)); - } - } - else { - // source is not needed, avoid load and lift compute - - // Step 5. Elementwise operation with conversion - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tSR_rAcc); ++i) { - tSR_rD(i) = epilogue_op(tSR_rAcc(i)); - } - } - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_gDmn); ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_gDmn); ++n) { - // Predication - if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { - // The Last Step. Copy to GMEM - copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); - } - } - } - } - } - } - } - -private: - Params params; - ThreadEpilogueOp epilogue_op; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp deleted file mode 100644 index 5030efded1e3608d91d0dca87f9f41fff827875f..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp +++ /dev/null @@ -1,412 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - -#pragma once - -#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Ptr Array Epilogue Vectorized -/// Applies an element wise operation to all elements within the fragment -/// and writes it out to destination storage. -/// -/// Ways to generalize this: -/// - CTA tile shape -/// - vectorization requirements (GMEM) -/// - vectoriz(able) transform() -/// -template < - class StrideC_, - class StrideD_, - class ThreadEpilogueOp_, - class SmemLayout_, - class CopyAtomR2S_, - class TiledCopyS2R_, - class CopyAtomR2G_, - class EpilogueScheduleType_ -> -class Epilogue< - StrideC_, - StrideD_, - ThreadEpilogueOp_, - SmemLayout_, - CopyAtomR2S_, - TiledCopyS2R_, - CopyAtomR2G_, - EpilogueScheduleType_, - cute::enable_if_t< - cute::is_same_v - > - > { -public: - // - // Type Aliases - // - // derived types of output thread level operator - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementScalar = ElementCompute; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementC = typename ThreadEpilogueOp::ElementC; - using StrideC = StrideC_; - using InternalStrideC = cute::remove_pointer_t; - using ElementD = typename ThreadEpilogueOp::ElementD; - using StrideD = StrideD_; - using InternalStrideD = cute::remove_pointer_t; - - using SmemLayout = SmemLayout_; - using CopyAtomR2S = CopyAtomR2S_; - using TiledCopyS2R = TiledCopyS2R_; - using CopyAtomR2G = CopyAtomR2G_; - - using GmemTiledCopyC = TiledCopyS2R; - using GmemTiledCopyD = TiledCopyS2R; - - static const int kOutputAlignment = ThreadEpilogueOp::kCount; - - using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - - static_assert(cute::rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(cute::rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - - struct SharedStorage - { - cute::array_aligned> smem_epilogue; - }; - - using TensorMapStorage = SharedStorage; - - // Host side epilogue arguments - struct Arguments { - typename ThreadEpilogueOp::Params thread{}; - ElementC const** ptr_C = nullptr; - StrideC dC{}; - ElementD** ptr_D = nullptr; - StrideD dD{}; - }; - - // Device side epilogue params - using Params = Arguments; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments( - ProblemShape const&, - Arguments const& args, - [[maybe_unused]] void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - template - static bool - can_implement( - [[maybe_unused]] ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - return true; - } - - CUTLASS_HOST_DEVICE - Epilogue(Params const& params_) - : params(params_) { } - - CUTLASS_DEVICE - bool - is_source_needed() { - // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. - return true; - } - - template< - class ProblemShapeMNKL, - class BlockShapeMNK, - class BlockCoordMNKL, - class FrgEngine, class FrgLayout, - class TiledMma, - class ResidueMNK - > - CUTLASS_DEVICE void - operator()( - ProblemShapeMNKL problem_shape_mnkl, - BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, - cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) - TiledMma tiled_mma, - ResidueMNK residue_mnk, - int thread_idx, - char* smem_buf) { - using namespace cute; - using X = Underscore; - - static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - - // synchronizing function for smem reads/writes -#if CUDA_BARRIER_ENABLED - auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; -#else - auto synchronize = [] () { __syncthreads(); }; -#endif - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - // Batches are managed by using appropriate pointers to C and D matrices - const int32_t mock_L = 1; - const int32_t mock_l_coord = 0; - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - - // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. - // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, - // we get the correct alpha/beta values for the current batch/group using group index. - ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); - - if (epilogue_op.is_source_needed() && params.dC == nullptr) { - // Beta value is non-zero while pointer to C is a nullptr - assert(0); - } - - InternalStrideC stride_c; - InternalStrideD stride_d; - if constexpr (!cute::is_same_v) { - // If grouped gemm - if (epilogue_op.is_source_needed()) { - stride_c = params.dC[l_coord]; - } - stride_d = params.dD[l_coord]; - } - else { - stride_c = params.dC; - stride_d = params.dD; - } - - // Represent the full output tensor - ElementC const* ptr_C_l = nullptr; - if (epilogue_op.is_source_needed()) { - ptr_C_l = params.ptr_C[l_coord]; - } - Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) - Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) - Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gC = gC_mnl(_,_,m_coord,n_coord,mock_l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_,_,m_coord,n_coord,mock_l_coord); // (BLK_M,BLK_N) - - // Construct a tensor in SMEM that we can partition for rearranging data - SharedStorage& storage = *reinterpret_cast(smem_buf); - Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) - - // Partition sAcc to match the accumulator partitioning - auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); - auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); - Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor tRS_sAcc = thread_r2s.partition_D(sAcc); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // Tile gD and gC by the shape of SmemLayout first - auto tile = make_shape(size<0>(sAcc), size<1>(sAcc)); - Tensor gCt = flat_divide(gC, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - Tensor gDt = flat_divide(gD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - - // Partition sAcc, gC, and gD for the output - auto tiled_s2r = TiledCopyS2R{}; - auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); - Tensor tSR_sAcc = thread_s2r.partition_S(sAcc); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tSR_gC = thread_s2r.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - Tensor tSR_gD = thread_s2r.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - - // Allocate intermediate registers on the dst tensors - Tensor tSR_rAcc = make_tensor(take<0,3>(shape(tSR_gC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tSR_rD = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) - - // Repeat the D-partitioning for coordinates and predication - Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) - Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - Tensor tSR_cD = thread_s2r.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - - CUTE_STATIC_ASSERT(size<1>(tRS_rAcc) % size<3>(tSR_gC) == 0); // TILE_M divides MMA_M - CUTE_STATIC_ASSERT(size<2>(tRS_rAcc) % size<4>(tSR_gC) == 0); // TILE_N divides MMA_N - -#if 0 - if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { - print("aC : "); print(accumulators.layout()); print("\n"); - print("gC : "); print(gC.layout()); print("\n"); - print("gD : "); print(gD.layout()); print("\n"); - print("sAcc : "); print(sAcc.layout()); print("\n"); - print("\n"); - print("tRS_sAcc : "); print(tRS_sAcc.layout()); print("\n"); - print("tRS_rAcc : "); print(tRS_rAcc.layout()); print("\n"); - print("\n"); - print("gDt : "); print(gDt.layout()); print("\n"); - print("tSR_sAcc : "); print(tSR_sAcc.layout()); print("\n"); - print("tSR_rAcc : "); print(tSR_rAcc.layout()); print("\n"); - print("\n"); - print("tSR_rD : "); print(tSR_rD.layout()); print("\n"); - print("tSR_gC : "); print(tSR_gC.layout()); print("\n"); - print("tSR_gD : "); print(tSR_gD.layout()); print("\n"); - print("\n"); - } -#endif - - // For each tiling needed for SmemLayout to cover shape(gD) - CUTLASS_PRAGMA_UNROLL - for (int step_m = 0; step_m < size<2>(cDt); ++step_m) { - CUTLASS_PRAGMA_UNROLL - for (int step_n = 0; step_n < size<3>(cDt); ++step_n) { - // Step 1. Copy to SMEM - CUTLASS_PRAGMA_UNROLL - for (int pipe_m = 0; pipe_m < size<1>(tRS_sAcc); ++pipe_m) { - CUTLASS_PRAGMA_UNROLL - for (int pipe_n = 0; pipe_n < size<2>(tRS_sAcc); ++pipe_n) { - int mma_m = step_m * size<1>(tRS_sAcc) + pipe_m; - int mma_n = step_n * size<2>(tRS_sAcc) + pipe_n; - - copy(tiled_r2s, tRS_rAcc(_,mma_m,mma_n), tRS_sAcc(_,pipe_m,pipe_n)); - } - } - - // Step 2. Wait for SMEM writes to complete - synchronize(); - - // Step 3. Copy from SMEM into a fragment - copy(tiled_s2r, tSR_sAcc, tSR_rAcc); - - // Step 4. Wait for SMEM reads to complete - synchronize(); - - Tensor tSR_gDmn = tSR_gD(_,_,_,step_m,step_n); - Tensor tSR_cDmn = tSR_cD(_,_,_,step_m,step_n); - - if (epilogue_op.is_source_needed()) { - // source is needed - Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); - - Tensor tSR_rCmn = make_tensor(shape(tSR_gCmn)); // ((Atom,AtomNum),ATOM_M,ATOM_N) - - // Step 5. Copy C from GMEM to a fragment - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_gDmn); ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_gDmn); ++n) { - // Predication - if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rAcc); ++i) { - tSR_rCmn(i,m,n) = tSR_gCmn(i,m,n); - } - } - } - } - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_gDmn); ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_gDmn); ++n) { - // Predication - if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { - // Step 6. Elementwise operation with conversion - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rAcc); ++i) { - tSR_rD(i,m,n) = epilogue_op(tSR_rAcc(i,m,n), tSR_rCmn(i,m,n)); - } - // Step 7. Copy to GMEM - copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); - } - } - } - } - else { - // source is not needed, avoid load and lift compute - - // Step 5. Elementwise operation with conversion - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tSR_rAcc); ++i) { - tSR_rD(i) = epilogue_op(tSR_rAcc(i)); - } - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_gDmn); ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_gDmn); ++n) { - // Predication - if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { - // Step 6. Copy to GMEM - copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); - } - } - } - } - } - } - } - -private: - Params params; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp deleted file mode 100644 index 77ef3ed2defbc2f286ac3002185a2864a8b322f8..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ /dev/null @@ -1,1245 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/arch/barrier.h" -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/epilogue/thread/scale_type.h" -#include "cutlass/epilogue/fusion/callbacks.hpp" -#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp" -#include "cutlass/detail/collective.hpp" -#include "cutlass/detail/layout.hpp" -#include "cutlass/trace.h" -#include "cutlass/cuda_host_adapter.hpp" - -#include "cute/tensor.hpp" -#include "cute/atom/copy_traits_sm90_tma.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_, - int NumEpilogueWarpGroups_, - class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) - class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class FusionCallbacks_, - class CopyOpG2S_, - class SmemLayoutAtomC_, - class CopyOpS2R_, - class CopyOpS2G_, - class SmemLayoutAtomD_, - class CopyOpR2S_, - class CopyAtomC_, - class CopyOpR2R_ -> -class CollectiveEpilogue< - Sm90PtrArrayTmaWarpSpecialized, - CtaTileMNK_, - EpilogueTile_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - FusionCallbacks_, - CopyOpG2S_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpS2G_, - SmemLayoutAtomD_, - CopyOpR2S_, - CopyAtomC_, - CopyOpR2R_ -> { -public: - // - // Type Aliases - // - using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized; - using CtaTileMNK = CtaTileMNK_; - using EpilogueTile = EpilogueTile_; - using FusionCallbacks = FusionCallbacks_; - using ElementC = ElementC_; - using StrideC = StrideC_; - using InternalStrideC = cute::remove_pointer_t; - using ElementD = ElementD_; - using StrideD = StrideD_; - using InternalStrideD = cute::remove_pointer_t; - using CopyOpG2S = CopyOpG2S_; - using SmemLayoutAtomC = SmemLayoutAtomC_; - using CopyOpS2R = CopyOpS2R_; - using CopyOpS2G = CopyOpS2G_; - using SmemLayoutAtomD = SmemLayoutAtomD_; - using CopyOpR2S = CopyOpR2S_; - using CopyAtomC = CopyAtomC_; - using CopyOpR2R = CopyOpR2R_; - - using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; - using GmemTiledCopyC = CopyOpG2S; - using GmemTiledCopyD = CopyOpS2G; - - static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); - static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); - static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); - static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); - static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); - static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); - static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); - -private: - constexpr static bool is_source_supported = not cute::is_void_v; - constexpr static bool is_destination_supported = not cute::is_void_v; - using NonVoidElementD = cute::conditional_t, ElementD>; - static_assert(not cute::is_void_v, "SmemElementD is void"); - using NonVoidElementC = cute::conditional_t; // prevents void ref breakages - - using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; - using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; - - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported; - constexpr static bool DelayTmaStore = DelayTmaStore_; - - constexpr static bool is_m_major_C = detail::is_m_major(); - constexpr static bool is_m_major_D = detail::is_m_major(); - - constexpr static bool is_im2col_C = cute::is_same_v; - constexpr static bool is_im2col_D = cute::is_same_v; - - // Check if register transformation is needed before copying register to shared memory. - constexpr static bool IsUseR2R = !cute::is_void_v; - - using SmemLayoutC = decltype(tile_to_shape( - SmemLayoutAtomC{}, - make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), - cute::conditional_t, Step<_1,_2,_3>>{} )); - using SmemLayoutD = decltype(tile_to_shape( - SmemLayoutAtomD{}, - make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), - cute::conditional_t, Step<_1,_2,_3>>{} )); - - constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC - && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); - static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); - - constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); - constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); - constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); - - using SmemArrayTypeC = cute::ArrayEngine>; - using SmemArrayTypeD = cute::ArrayEngine>; - - using EmptyType = cute::tuple<>; - using SmemCStorage = cute::conditional_t; - using SmemDStorage = cute::conditional_t; - - struct CollectiveStorageWithC { - alignas(SmemAlignmentC) ArrayEngine> smem_C; - alignas(SmemAlignmentD) ArrayEngine> smem_D; - }; - - union CollectiveStorageWithoutC { - cute::array smem_C; - alignas(SmemAlignmentD) ArrayEngine> smem_D; - }; - - union CollectiveStorageReuseC { - alignas(MaxSmemAlignment) ArrayEngine> smem_C; - alignas(MaxSmemAlignment) ArrayEngine> smem_D; - }; - -public: - // TMA pipeline for loading C - using LoadPipeline = cutlass::PipelineTransactionAsync; - using LoadPipelineState = cutlass::PipelineState; - constexpr static uint32_t TmaTransactionBytes = - (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; - constexpr static bool RequiresTransactionBytes = true; - - constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; - constexpr static uint32_t MinTensorMapWorkspaceAlignment = 64; - - // TMA pipeline for storing D - using StorePipeline = cute::conditional_t, - cutlass::PipelineTmaStore>; - using StorePipelineState = cutlass::PipelineState; - - struct SharedStorage { - struct TensorStorage { - using CollectiveStorage = cute::conditional_t>; - CollectiveStorage collective; - - using FusionStorage = typename FusionCallbacks::SharedStorage; - FusionStorage thread; - } tensors; - - struct TensorMapStorage : cute::aligned_struct<128, _0> { - cute::TmaDescriptor smem_tensormap_C; - cute::array smem_tensormap_D; - } tensormaps; - - using PipelineStorage = typename LoadPipeline::SharedStorage; - PipelineStorage pipeline; - }; - using TensorStorage = typename SharedStorage::TensorStorage; - using TensorMapStorage = typename SharedStorage::TensorMapStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; - - // Host side epilogue arguments - struct Arguments { - typename FusionCallbacks::Arguments thread{}; - ElementC const** ptr_C = nullptr; - StrideC dC; - ElementD ** ptr_D = nullptr; - StrideD dD; - }; - - // Device side epilogue params - struct Params { - using TMA_C = decltype(make_tma_copy( - CopyOpG2S{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), - repeat_like(InternalStrideC{}, int32_t(0)), InternalStrideC{}), - take<0,2>(SmemLayoutC{}), - EpilogueTile{}, - _1{})); - - using TMA_D = decltype(make_tma_copy( - CopyOpS2G{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), - repeat_like(InternalStrideD{}, int32_t(0)), InternalStrideD{}), - take<0,2>(SmemLayoutD{}), - EpilogueTile{}, - _1{})); - - typename FusionCallbacks::Params thread{}; - TMA_C tma_load_c; - TMA_D tma_store_d; - cute::TmaDescriptor* tensormaps; - ElementC const** ptr_C; - StrideC dC; - ElementD** ptr_D; - StrideD dD; - uint32_t tma_transaction_bytes = TmaTransactionBytes; - }; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments( - ProblemShape const& problem_shape, - Arguments const& args, - [[maybe_unused]] void* workspace) { - // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. - // These will be replaced with correct values before the initial tma load. - auto init_M = int32_t(size<0>(CtaTileMNK{})); - auto init_N = int32_t(size<1>(CtaTileMNK{})); - auto init_L = 1; - - static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D"); - - InternalStrideC stride_c; - InternalStrideD stride_d; - if constexpr (IsGroupedGemmKernel) { - // Strides for Grouped Gemm will be replaced prior to the first access regardless. - stride_c = InternalStrideC{}; - stride_d = InternalStrideD{}; - } - else { - // Tensor shapes for Ptr-Array are initialized correctly only here. - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1); - init_M = get<0>(problem_shape_MNKL); - init_N = get<1>(problem_shape_MNKL); - stride_c = args.dC; - stride_d = args.dD; - } - - uint32_t transaction_bytes = TmaTransactionBytes; - typename Params::TMA_C tma_load_c{}; - if constexpr (is_source_supported) { - // NOTE: Since TMA desc creation with nullptr not possible until 12.6, we use an initial address even when tensor addresses are on device. This address is never used. - ElementC const* ptr_C_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_C) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned - Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{}))); - tma_load_c = make_tma_copy( - CopyOpG2S{}, - tensor_c, - take<0,2>(SmemLayoutC{}), - EpilogueTile{}, - _1{}); - } - - typename Params::TMA_D tma_store_d{}; - if constexpr (is_destination_supported) { - // NOTE: Since TMA desc creation with nullptr not possible until 12.6, we use an initial address even when tensor addresses are on device. This address is never used. - ElementD const* ptr_D_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_D) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned - Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); - tma_store_d = make_tma_copy( - CopyOpS2G{}, - tensor_d, - take<0,2>(SmemLayoutD{}), - EpilogueTile{}, - _1{}); - } - - auto fusion_workspace = static_cast(workspace); - auto fusion_workspace_size = round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment); - auto tma_descriptor_workspace = reinterpret_cast( - static_cast(workspace) + fusion_workspace_size); - - return { - FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, fusion_workspace), - tma_load_c, - tma_store_d, - tma_descriptor_workspace, - args.ptr_C, - args.dC, - args.ptr_D, - args.dD, - transaction_bytes, - }; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { - constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1); - auto descriptors_shape = cute::make_shape(sm_count, Int{}); - constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); - // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies - return (size(descriptors_shape) * SizeOfCuTensorMap) + - (round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment)); - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); - } - - template - static bool - can_implement( - ProblemShape problem_shape, - [[maybe_unused]] Arguments const& args) { - - bool implementable = true; - bool fusion_implementable = true; - - if (problem_shape.is_host_problem_shape_available()) { - for (int i = 0; i < problem_shape.groups(); ++i) { - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); - auto [M,N,K,L] = problem_shape_MNKL; - - if constexpr (is_destination_supported) { - constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); - constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); - } - - if constexpr (is_source_supported) { - constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); - constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); - } - - fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); - } - } - else { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n"); - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - } - - if (!fusion_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); - } - - bool beta_implementable = true; - - if (cute::is_void_v || args.ptr_C == nullptr) { - if constexpr (detail::has_beta::value) { - beta_implementable = args.thread.beta == 0.0; - } - if constexpr (detail::has_beta_ptr::value) { - beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; - } - if constexpr (detail::has_beta_ptr_array::value) { - beta_implementable = beta_implementable && args.thread.beta_ptr_array == nullptr; - } - } - - if (!beta_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); - } - - return implementable && fusion_implementable && beta_implementable; - } - - template - CUTLASS_HOST_DEVICE - static constexpr int - get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { - // Compute number of epilogue subtiles - return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{})); - } - - template - CUTLASS_HOST_DEVICE - static constexpr int - get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { - return get_load_pipe_increment(tile_shape_MNK); - } - - CUTLASS_HOST_DEVICE - CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) - : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} - - CUTLASS_DEVICE - bool - is_producer_load_needed() const { - return fusion_callbacks.is_producer_load_needed(); - } - - CUTLASS_DEVICE auto - load_init( - Params const& params, - TensorMapStorage& shared_tensormaps, - int32_t sm_count, - int32_t sm_idx) { - // Initialize tma for loading - constexpr bool IsLoad = true; - auto load_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, 0); - return load_tensormaps; - } - - template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class TiledMma, - class TensorMapC, - __CUTE_REQUIRES(std::is_pointer_v) - > - CUTLASS_DEVICE auto - load( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_MNK, - TileCoordMNKL tile_coord_mnkl, - TiledMma tiled_mma, - int thread_idx, - TensorStorage& shared_tensors, - TensorMapC const& load_tensormap, - int subtile_idx=-1) { - using namespace cute; - - // Indexing variables - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - - static_assert(!is_im2col_D, "Do not support im2col"); - - auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{}); - - // Represent the full source tensor, slice to get the tile this CTA is currently responsible for - Tensor mC_mn = params.tma_load_c.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) - Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); - Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) - - // Apply epilogue subtile, get matching smem tensor - auto ptr_sC = shared_tensors.collective.smem_C.begin(); - Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) - - // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) - ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); - Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (G2S,G2S_M,G2S_N,EPI_M,EPI_N) - Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) - - // Get the fusion callbacks for the producer load warp - auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ - problem_shape_mnkl, - CtaTileMNK{}, - tile_coord_mnkl, - tiled_mma, - EpilogueTile{}, - thread_idx - }; - auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - LoadPipelineState last_load_producer_state = load_pipe_producer_state; - - // Predication for TMA load (one thread issues TMA load) - bool issue_tma_load = cute::elect_one_sync(); - - // Pre-loop fusion callback entry point - pld_callbacks.begin(); - - LoadPipelineState prior_state = load_pipe_producer_state; - - bool did_load = false; - - CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { - CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { - if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { - continue; - } - - // Acquire the lock for this stage - constexpr uint16_t mcast_mask = 0; - uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); - - load_pipeline.producer_acquire(load_pipe_producer_state); - - // Loop fusion callback entry point - pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); - - // Execute the TMA load for C if needed - if (is_C_load_needed) { - if (issue_tma_load) { - copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask), - bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); - load_pipeline.producer_expect_transaction(load_pipe_producer_state); - } - last_load_producer_state = load_pipe_producer_state; - did_load = true; - } - - // Commit TMA loads for this stage and release the lock - load_pipeline.producer_commit(load_pipe_producer_state); - ++load_pipe_producer_state; - } - } - - // Post-loop fusion callback entry point - pld_callbacks.end(); - - return load_pipe_producer_state; - } - - CUTLASS_DEVICE auto - load_tail( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state) { - - if (!fusion_callbacks.is_producer_load_needed()) { - return load_pipe_producer_state; - } - - bool issue_tma_load = cute::elect_one_sync(); - if (issue_tma_load) { - load_pipeline.producer_tail(load_pipe_producer_state); - } - - return load_pipe_producer_state; - } - - template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class AccEngine, class AccLayout, - class TiledMma, - class TensorMapD - > - CUTLASS_DEVICE auto - store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_MNK, - TileCoordMNKL tile_coord_mnkl, - cute::Tensor accumulators, - TiledMma tiled_mma, - int thread_idx, - TensorStorage& shared_tensors, - TensorMapD const& store_tensormap, - int subtile_idx=-1) { - - using namespace cute; - using ElementAccumulator = typename AccEngine::value_type; - using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; - using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; - - static_assert(is_rmem::value, "Accumulator must be RF resident."); - static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "TileShapeMNK must be static"); - static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); - static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); - - // Indexing variables - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - - - static_assert(!is_im2col_D, "Do not support im2col"); - - auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{}); - - // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) - - Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); - Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) - - // Apply epilogue subtiling - Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Construct the corresponding pipelined smem tensors - auto ptr_sC = shared_tensors.collective.smem_C.begin(); - auto ptr_sD = shared_tensors.collective.smem_D.begin(); - Tensor sC_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) - Tensor sD_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) - - TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); - - // (t)hread-partition for (r)egister to (r)egister copy (tRR_) - TiledCopy tiled_r2r = [&]() { - if constexpr (IsUseR2R) { - return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); - } - else { - return make_tiled_copy_S(Copy_Atom, - ElementCompute>{}, tiled_copy_C_atom); - } - }(); - ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); - - // (t)hread-partition for (r)egister to (s)mem copy (tRS_) - TiledCopy tiled_r2s = [&]() { - if constexpr (IsUseR2R) { - return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); - } - else { - return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); - } - }(); - ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); - Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) - Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) - - auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc); - auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc); - auto epi_tile_m = size<0>(EpilogueTile{}); - auto epi_tile_n = size<1>(EpilogueTile{}); - - // Allocate D registers - Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); - Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) - - // Vectorized fragment view - constexpr int FragmentSize = DispatchPolicy::FragmentSize; - Tensor tRS_rAcc_frg = recast>(tRS_rAcc); - Tensor tRS_rD_frg = recast>(tRS_rD); - CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); - - // (t)hread-partition for (s)mem to (r)egister copy (tSR_) - TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); - ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); - Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) - Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N) - - // Allocate C registers - // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type - // to eliminate some redundant pack+unpack instruction sequences for sub-word types - constexpr bool IsDirectS2R = cute::is_same_v> - && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; - using RegisterElementC = cute::conditional_t; - Tensor tRS_rC = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) - Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) - - // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) - ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); - Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) - Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - - // OOB predication for tile quantization "residue" - // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tRS_cD_mn = thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) - // Relative coordinate tensors (static) - Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) - Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) - // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate - auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) - auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n) - - CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); - - if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { - // When the epilogue subtile is larger than the MMA tiles, loop over multiple MMA tiles - CUTE_STATIC_ASSERT(epi_tile_n % mma_tile_n == 0, "MMA_TILE_N must divide EPI_TILE_N"); - } - else { - CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); - } - - // Get TiledCopy for partition reference when consumer store. - TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); - // Get the fusion callbacks for the consumer store warps - constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, - CtaTileMNK{}, - tile_coord_mnkl, - tiled_mma, - EpilogueTile{}, - tiled_copy_partition_ref, - cD, - residue_cD, - tRS_cD, - residue_tRS_cD, - tRS_rC, - thread_idx - }; - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg(0), 0, 0, 0)); - constexpr bool IsDirectR2S = cute::is_same_v>; - using RegisterElementD = cute::conditional_t; - Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) - Tensor tRS_rCompute_frg = recast>(tRS_rCompute); - - // Thread synchronizer for previously issued waits or fences - // to ensure visibility of smem reads/writes to threads or TMA unit - auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // Predication for TMA store (a single thread from one warp issues TMA store) - bool issue_tma_store = ((thread_idx / NumThreadsPerWarp) == 0) && cute::elect_one_sync(); - - // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. - // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can - // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. - // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. - LoadPipelineState load_wait_state = load_pipe_consumer_state; - if constexpr (ReuseSmemC) { - load_wait_state = store_pipe_producer_state; - load_wait_state.phase_ ^= 1; - } - - // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions - // Sync requirements of smem reuse may preclude this optimization - // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD - int epi_m_prev = 0, epi_n_prev = 0; - static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); - - // The TMA store sequence for one subtile iteration - auto tma_store_fn = [&] (int epi_m, int epi_n) { - // Write the tile from smem to gmem with TMA - cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - synchronize(); // ensure all threads have issued their async fence - if constexpr (is_destination_supported) { - if (issue_tma_store) { - copy(params.tma_store_d.with(store_tensormap), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); - } - } - - // Post async fence, pre TMA commit callback entry point - cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); - - // Commit the TMA stores for this stage - if (issue_tma_store) { - store_pipeline.producer_commit(store_pipe_producer_state); - } - ++store_pipe_producer_state; - ++issued_stores; - - // Wait for the next smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); - } - synchronize(); - - if constexpr (ReuseSmemC) { - // producer_acquire returns when at most StagesD-1 committed stores are pending - bool store_finished = issued_stores > StorePipeline::UnacquiredStages; - // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits - if (store_finished) { - if (is_producer_load_needed) { - load_pipeline.consumer_release(load_pipe_consumer_state); - } - ++load_pipe_consumer_state; - } - } - }; - - // - // BEGIN EPILOGUE - // - - // Pre-loop fusion callback entry point - cst_callbacks.begin(); - if (cst_callbacks.begin_sync_needed()) { - synchronize(); - } - - // For each output tile - CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { - CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { - bool is_first_iteration = epi_m == 0 && epi_n == 0; - bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; - - if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { - continue; - } - - cst_callbacks.begin_loop(epi_m, epi_n); - - if (is_producer_load_needed) { - // Wait for the producer load to fill smem - load_pipeline.consumer_wait(load_wait_state); - - if (is_C_load_needed) { - // Copy source tile from smem to register - copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); - } - } - - // First loop fusion callback entry point - cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); - - if (is_producer_load_needed) { - if constexpr (not ReuseSmemC) { - // Let producer load warp know smem buffers are consumed and empty - cutlass::arch::fence_view_async_shared(); - load_pipeline.consumer_release(load_pipe_consumer_state); - ++load_pipe_consumer_state; - } - ++load_wait_state; - } - - if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { - // When the epilogue subtile is larger than the MMA tiles, loop over multiple - // MMA tiles - static constexpr int MmaMPerEpiM = epi_tile_m / mma_tile_m; - static constexpr int MmaNPerEpiN = epi_tile_n / mma_tile_n; - - CUTLASS_PRAGMA_UNROLL - for (int mma_n_in_epi = 0; mma_n_in_epi < MmaNPerEpiN; ++mma_n_in_epi) { - int mma_n = (epi_n * MmaNPerEpiN) + mma_n_in_epi; - - CUTLASS_PRAGMA_UNROLL - for (int mma_m_in_epi = 0; mma_m_in_epi < MmaMPerEpiM; ++mma_m_in_epi) { - int mma_m = (epi_m * MmaMPerEpiM) + mma_m_in_epi; - Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); - int idx_in_epi_subtile = (mma_n_in_epi * MmaMPerEpiM + mma_m_in_epi); - - tRS_rCompute_frg(idx_in_epi_subtile) = cst_callbacks.visit( - tRS_rAcc_frg_mn(0), idx_in_epi_subtile, epi_m, epi_n); - } - } - } - else { - int mma_m = epi_m; - int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; - Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); - - // Vectorized fragment loop with visitor callback entry point - int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); - int r2s_v = epi_n_in_mma * size(tRS_rCompute_frg); - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tRS_rCompute_frg); ++epi_v) { - tRS_rCompute_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); - } - } - - // The latest we can delay the TMA store is right before the smem store of the next iteration - // since the current TMA store needs to be committed before we can acquire the next smem buffer - if constexpr (DelayTmaStore) { - // Issue TMA stores for the previous subtile - if (not is_first_iteration and subtile_idx == -1) { - tma_store_fn(epi_m_prev, epi_n_prev); - } - epi_m_prev = epi_m; - epi_n_prev = epi_n; - } - - // Smem reduction callback entry point using current store buffer for workspace - cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), - synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); - - // Copy tile from register to regiser if needed - if constexpr (IsUseR2R) { - // retile source and destination for tiled_r2r - Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) - Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) - - // Output needs register shuffling before copying to shared memory. - copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); - } - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tRS_rD_frg); ++i) { - tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); - } - - // Copy tile from register to smem - if constexpr (is_destination_supported) { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - } - - // Post reduction, pre TMA store callback entry point - constexpr bool issue_smem_store = true; // No smem store predication - cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); - - if constexpr (not DelayTmaStore) { - // Issue TMA stores for this subtile - tma_store_fn(epi_m, epi_n); - } - - cst_callbacks.end_loop(epi_m, epi_n); - - } // for epi_m - } // for epi_n - - - if constexpr (DelayTmaStore) { - // Issue TMA stores for the last subtile - tma_store_fn(epi_m_prev, epi_n_prev); - } - - // Post-loop fusion callback entry point - cst_callbacks.end(); - - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); - } - - CUTLASS_DEVICE auto - store_tail( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state) { - // wait for all TMA stores to complete - store_pipeline.producer_tail(store_pipe_producer_state); - // reset store counter - issued_stores = 0; - - if constexpr (ReuseSmemC) { - if (fusion_callbacks.is_producer_load_needed()) { - // Issue releases on up to StagesD-1 previously issued TMA stores - constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{})); - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < release_stages; ++stage) { - load_pipeline.consumer_release(load_pipe_consumer_state); - ++load_pipe_consumer_state; - } - } - } - - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); - } - - CUTLASS_DEVICE auto - store_init( - Params const& params, - TensorMapStorage& shared_tensormaps, - int32_t sm_count, - int32_t sm_idx, - int32_t warp_group_idx) { - int warp_idx_in_warp_group = canonical_warp_idx_sync() % NumWarpsPerWarpGroup; - // Since only one warp issues TMA store, we only need that one warp to initialize tensormaps - if (warp_idx_in_warp_group == 0) { - // Initialize tma - constexpr bool IsLoad = false; - auto store_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, warp_group_idx); - return store_tensormaps; - } - TmaDescriptor* null_tma_desc = nullptr; - return cute::make_tuple(null_tma_desc); - } - - // - // Methods to perform different parts of TMA/Tensormap modifications - // - - template - CUTLASS_DEVICE auto - tensormaps_init( - Params const& params, - TensorMapStorage& shared_tensormaps, - int32_t sm_count, - int32_t sm_idx, - int32_t warp_group_idx) { - - constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1); - Layout desc_layout = make_layout(make_shape(sm_count, Int{})); - - Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout); // (SMs, NumInputTensors) - - if constexpr (IsLoad) { - if (is_source_supported) { - constexpr int C_tensormap_index = NumEpilogueWarpGroups; - Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_C), Int<1>{}, Int<1>{}); - - if (cute::elect_one_sync()) { - // Bringing tensormaps from params to smem for modification later - copy(recast(pC_tensormap), recast(sC_tensormap)); - } - __syncwarp(); - return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index)); - - } - TmaDescriptor* null_tma_desc = nullptr; - return cute::make_tuple(null_tma_desc); - } - else { - Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_D[warp_group_idx]), Int<1>{}, Int<1>{}); - - if (cute::elect_one_sync()) { - // Bringing tensormaps from params to smem for modification later - copy(recast(pD_tensormap), recast(sD_tensormap)); - } - __syncwarp(); - return cute::make_tuple(&gmem_tensormap(sm_idx, warp_group_idx)); - } - } - - // Replace address for the global tensor (to be done by single thread) - template - CUTLASS_DEVICE - void - tensormaps_replace_global_address( - TensorMapStorage& shared_tensormaps, - Params const& params, - int32_t next_batch, - int32_t warp_group_idx) { - // Replacing global_address for the next batch - if constexpr (IsLoad) { - if constexpr (is_source_supported) { - if (params.ptr_C != nullptr) { - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C, - params.ptr_C[next_batch]); - } - } - } - else if constexpr (is_destination_supported) { - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx], - params.ptr_D[next_batch]); - } - } - - // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) - template - CUTLASS_DEVICE - void - tensormaps_replace_global_tensor_properties( - TensorMapStorage& shared_tensormaps, - Params const& params, - int32_t next_group, - ProblemShape_MNKL problem_shape_mnkl, - int32_t warp_group_idx) { - const uint32_t M = get<0>(problem_shape_mnkl); - const uint32_t N = get<1>(problem_shape_mnkl); - // Replace all dims for consistency - constexpr int MaxTensorRank = 5; - cute::array prob_shape = {1,1,1,1,1}; - cute::array prob_stride = {0,0,0,0,0}; - - if constexpr (IsLoad) { - if constexpr (is_source_supported) { - if (params.dC != nullptr) { - ElementC const* ptr_C = nullptr; - Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); - - cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, - prob_shape, prob_stride); - // Convert strides to byte strides - for (uint64_t& stride : prob_stride) { - stride = (stride * sizeof_bits_v) / 8; - } - cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, - prob_shape, - prob_stride); - } - } - } - else if constexpr (is_destination_supported) { - ElementD const* ptr_D = nullptr; - Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); - - cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d, - prob_shape, prob_stride); - // Convert strides to byte strides - for (uint64_t& stride : prob_stride) { - stride = (stride * sizeof_bits_v) / 8; - } - - cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx], - prob_shape, - prob_stride); - } - } - - template - CUTLASS_DEVICE - void - tensormaps_perform_update( - TensorMapStorage& shared_tensormaps, - Params const& params, - cute::TmaDescriptor const* tensormap, - ProblemShape_MNKL problem_shape_mnkl, - int32_t next_batch, - int32_t warp_group_idx) { - if (cute::elect_one_sync()) { - // Replacing global_address for the next batch - tensormaps_replace_global_address(shared_tensormaps, params, next_batch, warp_group_idx); - - if constexpr (IsGroupedGemmKernel) { - // Replacing global dims and strides for the next batch - tensormaps_replace_global_tensor_properties( - shared_tensormaps, params, next_batch, problem_shape_mnkl, warp_group_idx); - } - - } - } - - template - CUTLASS_DEVICE - void - tensormaps_cp_fence_release( - TensorMapStorage& shared_tensormaps, - cute::TmaDescriptor const* tensormap, - const int32_t warp_group_idx = 0) { - // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem. - // This operation only happens when the group/batch changes between consecutive tiles. - // If there are no uncommitted instructions then tma_desc_commit_group results in an empty bulk async-group. - auto tma_desc_wait_all_fn = [] () CUTLASS_LAMBDA_FUNC_INLINE { - if (cute::elect_one_sync()) { - cute::tma_desc_commit_group(); - cute::tma_desc_wait_group(); - } - }; - // Entire warp must do this (ie its aligned) - if constexpr (IsLoad) { - if constexpr (is_source_supported) { - tma_desc_wait_all_fn(); - tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_C); - } - } - else if constexpr (is_destination_supported) { - tma_desc_wait_all_fn(); - tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_D[warp_group_idx]); - } - } - - template - CUTLASS_DEVICE - void - tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { - if constexpr (IsLoad) { - if constexpr (is_source_supported) { - cute::tma_descriptor_fence_acquire(tensormap); - } - } - else { - cute::tma_descriptor_fence_acquire(tensormap); - } - } - -private: - Params const& params; - FusionCallbacks fusion_callbacks; - int issued_stores = 0; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp deleted file mode 100644 index 062b9a8b582a1a3c05407f163a0ca4b05646028a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ /dev/null @@ -1,958 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/arch/barrier.h" -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/epilogue/thread/scale_type.h" -#include "cutlass/epilogue/fusion/callbacks.hpp" -#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp" -#include "cutlass/detail/collective.hpp" -#include "cutlass/detail/layout.hpp" -#include "cutlass/detail/helper_macros.hpp" -#include "cutlass/trace.h" - -#include "cute/tensor.hpp" -#include "cutlass/cuda_host_adapter.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_, - class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) - class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class FusionCallbacks_, - class CopyOpG2S_, - class SmemLayoutAtomC_, - class CopyOpS2R_, - class CopyOpS2G_, - class SmemLayoutAtomD_, - class CopyOpR2S_, - class CopyAtomC_, - class CopyOpR2R_ -> -class CollectiveEpilogue< - Sm90TmaWarpSpecialized, - CtaTileMNK_, - EpilogueTile_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - FusionCallbacks_, - CopyOpG2S_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpS2G_, - SmemLayoutAtomD_, - CopyOpR2S_, - CopyAtomC_, - CopyOpR2R_ -> { -public: - // - // Type Aliases - // - using DispatchPolicy = Sm90TmaWarpSpecialized; - using CtaTileMNK = CtaTileMNK_; - using EpilogueTile = EpilogueTile_; - using FusionCallbacks = FusionCallbacks_; - using ElementC = ElementC_; - using StrideC = StrideC_; - using ElementD = ElementD_; - using StrideD = StrideD_; - using CopyOpG2S = CopyOpG2S_; - using SmemLayoutAtomC = SmemLayoutAtomC_; - using CopyOpS2R = CopyOpS2R_; - using CopyOpS2G = CopyOpS2G_; - using SmemLayoutAtomD = SmemLayoutAtomD_; - using CopyOpR2S = CopyOpR2S_; - using CopyAtomC = CopyAtomC_; - using CopyOpR2R = CopyOpR2R_; - - using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; - using GmemTiledCopyC = CopyOpG2S; - using GmemTiledCopyD = CopyOpS2G; - - static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); - static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); - static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); - static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); - static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); - -private: - constexpr static bool is_source_supported = not cute::is_void_v; - constexpr static bool is_destination_supported = not cute::is_void_v; - using NonVoidElementD = cute::conditional_t, ElementD>; - static_assert(not cute::is_void_v, "SmemElementD is void"); - using NonVoidElementC = cute::conditional_t; // prevents void ref breakages - - using TmaElementD = cute::conditional_t>, uint64_t, NonVoidElementD>; - using TmaElementC = cute::conditional_t>, uint64_t, NonVoidElementC>; - - using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; - using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; - - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported; - constexpr static bool DelayTmaStore = DelayTmaStore_; - - constexpr static bool is_m_major_C = detail::is_m_major(); - constexpr static bool is_m_major_D = detail::is_m_major(); - - constexpr static bool is_im2col_C = cute::is_same_v; - constexpr static bool is_im2col_D = cute::is_same_v; - - // Check if register transformation is needed before copying register to shared memory. - constexpr static bool IsUseR2R = !cute::is_void_v; - - using SmemLayoutC = decltype(tile_to_shape( - SmemLayoutAtomC{}, - make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), - cute::conditional_t, Step<_1,_2,_3>>{} )); - using SmemLayoutD = decltype(tile_to_shape( - SmemLayoutAtomD{}, - make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), - cute::conditional_t, Step<_1,_2,_3>>{} )); - - constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC - && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); - static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); - - constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); - constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); - constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); - - using SmemArrayTypeC = cute::ArrayEngine>; - using SmemArrayTypeD = cute::ArrayEngine>; - - using EmptyType = cute::tuple<>; - using SmemCStorage = cute::conditional_t; - using SmemDStorage = cute::conditional_t; - - struct CollectiveStorageWithC { - alignas(SmemAlignmentC) ArrayEngine> smem_C; - alignas(SmemAlignmentD) ArrayEngine> smem_D; - }; - - union CollectiveStorageWithoutC { - cute::array smem_C; - alignas(SmemAlignmentD) ArrayEngine> smem_D; - }; - - union CollectiveStorageReuseC { - alignas(MaxSmemAlignment) ArrayEngine> smem_C; - alignas(MaxSmemAlignment) ArrayEngine> smem_D; - }; - -public: - // TMA pipeline for loading C - using LoadPipeline = cutlass::PipelineTransactionAsync; - using LoadPipelineState = cutlass::PipelineState; - constexpr static uint32_t TmaTransactionBytes = - (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; - constexpr static bool RequiresTransactionBytes = true; - - // TMA pipeline for storing D - using StorePipeline = cute::conditional_t, - cutlass::PipelineTmaStore>; - using StorePipelineState = cutlass::PipelineState; - - struct SharedStorage { - struct TensorStorage { - using CollectiveStorage = cute::conditional_t>; - CollectiveStorage collective; - - using FusionStorage = typename FusionCallbacks::SharedStorage; - FusionStorage thread; - } tensors; - - using PipelineStorage = typename LoadPipeline::SharedStorage; - PipelineStorage pipeline; - }; - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - // Host side epilogue arguments - struct Arguments { - typename FusionCallbacks::Arguments thread{}; - ElementC const* ptr_C; - StrideC dC; - ElementD const* ptr_D; - StrideD dD; - }; - - // Device side epilogue params - struct Params { - using TMA_C = decltype(make_tma_copy( - CopyOpG2S{}, - make_tensor(make_gmem_ptr(nullptr), - repeat_like(StrideC{}, int32_t(0)), StrideC{}), - take<0,2>(SmemLayoutC{}), - EpilogueTile{}, - _1{})); - using TMA_D = decltype(make_tma_copy( - CopyOpS2G{}, - make_tensor(make_gmem_ptr(nullptr), - repeat_like(StrideD{}, int32_t(0)), StrideD{}), - take<0,2>(SmemLayoutD{}), - EpilogueTile{}, - _1{})); - - typename FusionCallbacks::Params thread{}; - TMA_C tma_load_c; - TMA_D tma_store_d; - uint32_t tma_transaction_bytes = TmaTransactionBytes; - }; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments( - ProblemShape const& problem_shape, - Arguments const& args, - [[maybe_unused]] void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - uint32_t transaction_bytes = TmaTransactionBytes; - typename Params::TMA_C tma_load_c{}; - if constexpr (is_source_supported) { - Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); - tma_load_c = make_tma_copy_C_sm90( - CopyOpG2S{}, - tensor_c, - take<0,2>(SmemLayoutC{}), - EpilogueTile{}); - } - - typename Params::TMA_D tma_store_d{}; - if constexpr (is_destination_supported) { - Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); - tma_store_d = make_tma_copy_C_sm90( - CopyOpS2G{}, - tensor_d, - take<0,2>(SmemLayoutD{}), - EpilogueTile{}); - } - - return { - FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), - tma_load_c, - tma_store_d, - transaction_bytes - }; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return FusionCallbacks::get_workspace_size(problem_shape, args.thread); - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); - } - - template - static bool - can_implement( - ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - auto shape = cute::make_shape(M,N,L); - - bool implementable = true; - if constexpr (is_destination_supported) { - constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); - constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; - if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm - implementable = cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideD{})); - } - else { - implementable = cutlass::detail::check_alignment(shape, StrideD{}); - } - } - - if constexpr (not cute::is_void_v) { - constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); - constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; - if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm - implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideC{})); - } - else { - implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); - } - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - } - - bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); - - if (!fusion_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); - } - - bool beta_implementable = true; - - if constexpr (cute::is_void_v) { - if constexpr (detail::has_beta::value) { - beta_implementable = args.thread.beta == 0.0; - } - if constexpr (detail::has_beta_ptr::value) { - beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; - } - } - - if (!beta_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); - } - - return implementable && fusion_implementable && beta_implementable; - } - - template - CUTLASS_HOST_DEVICE - static constexpr int - get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { - // Compute number of epilogue subtiles - return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{})); - } - - template - CUTLASS_HOST_DEVICE - static constexpr int - get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { - return get_load_pipe_increment(tile_shape_MNK); - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void - prefetch_tma_descriptors(Params const& epilogue_params) { - if constexpr (is_source_supported) { - cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); - } - if constexpr (is_destination_supported) { - cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); - } - } - - CUTLASS_HOST_DEVICE - CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) - : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} - - CUTLASS_DEVICE - bool - is_producer_load_needed() const { - return fusion_callbacks.is_producer_load_needed(); - } - - template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class TiledMma - > - CUTLASS_DEVICE auto - load( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_MNK, - TileCoordMNKL tile_coord_mnkl, - TiledMma tiled_mma, - int thread_idx, - TensorStorage& shared_tensors, - int subtile_idx=-1) { - using namespace cute; - - // Indexing variables - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - - // The tma tensor C under im2col mode only has two modes (M, N) which - // should be local tiled with only (m_coord, n_coord). - auto coord_shape = conditional_return( - make_coord(m_coord, n_coord), - make_coord(m_coord, n_coord, l_coord)); - - // Represent the full source tensor, slice to get the tile this CTA is currently responsible for - Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); - Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) - - // Apply epilogue subtile, get matching smem tensor - auto ptr_sC = shared_tensors.collective.smem_C.begin(); - Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) - - // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) - ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); - Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (G2S,G2S_M,G2S_N,EPI_M,EPI_N) - Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) - - // Get the fusion callbacks for the producer load warp - auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs( - problem_shape_mnkl, - CtaTileMNK{}, - tile_coord_mnkl, - tiled_mma, - EpilogueTile{}, - thread_idx - ); - auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - // Predication for TMA load (one thread issues TMA load) - bool issue_tma_load = cute::elect_one_sync(); - - // Pre-loop fusion callback entry point - pld_callbacks.begin(); - - CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { - CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { - if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { - continue; - } - // Acquire the lock for this stage - constexpr uint16_t mcast_mask = 0; - uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); - load_pipeline.producer_acquire(load_pipe_producer_state); - - // Loop fusion callback entry point - pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); - - // Execute the TMA load for C if needed - if (issue_tma_load && is_C_load_needed) { - copy(params.tma_load_c.with(*tma_barrier, mcast_mask), - bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); - load_pipeline.producer_expect_transaction(load_pipe_producer_state); - } - - // Commit TMA loads for this stage and release the lock - load_pipeline.producer_commit(load_pipe_producer_state); - ++load_pipe_producer_state; - } - } - - // Post-loop fusion callback entry point - pld_callbacks.end(); - - return load_pipe_producer_state; - } - - CUTLASS_DEVICE auto - load_tail( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state) { - bool issue_tma_load = cute::elect_one_sync(); - if (issue_tma_load) { - load_pipeline.producer_tail(load_pipe_producer_state); - } - - return load_pipe_producer_state; - } - - template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class AccEngine, class AccLayout, - class TiledMma - > - CUTLASS_DEVICE auto - store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_MNK, - TileCoordMNKL tile_coord_mnkl, - cute::Tensor accumulators, - TiledMma tiled_mma, - int thread_idx, - TensorStorage& shared_tensors, - int subtile_idx=-1) { - using namespace cute; - using ElementAccumulator = typename AccEngine::value_type; - using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; - using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; - - static_assert(is_rmem::value, "Accumulator must be RF resident."); - static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "TileShapeMNK must be static"); - static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); - static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); - - // Indexing variables - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - - // The tma tensor D under im2col mode only has two modes (M, N) which - // should be local tiled with only (m_coord, n_coord). - auto coord_shape = conditional_return( - make_coord(m_coord, n_coord), - make_coord(m_coord, n_coord, l_coord)); - - // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); - Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) - - // Apply epilogue subtiling - Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Construct the corresponding pipelined smem tensors - auto ptr_sC = shared_tensors.collective.smem_C.begin(); - auto ptr_sD = shared_tensors.collective.smem_D.begin(); - Tensor sC_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) - Tensor sD_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) - - TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); - - // (t)hread-partition for (r)egister to (r)egister copy (tRR_) - TiledCopy tiled_r2r = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (IsUseR2R) { - return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); - } - else { - return make_tiled_copy_S(Copy_Atom, - ElementCompute>{}, tiled_copy_C_atom); - } - }(); - ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); - - // (t)hread-partition for (r)egister to (s)mem copy (tRS_) - TiledCopy tiled_r2s = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (IsUseR2R) { - return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); - } - else { - return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); - } - }(); - ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); - Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) - Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) - - auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc); - auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc); - auto epi_tile_m = size<0>(EpilogueTile{}); - auto epi_tile_n = size<1>(EpilogueTile{}); - - // Allocate D registers - Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); - Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) - - // Vectorized fragment view - constexpr int FragmentSize = DispatchPolicy::FragmentSize; - Tensor tRS_rAcc_frg = recast>(tRS_rAcc); - Tensor tRS_rD_frg = recast>(tRS_rD); - CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); - - // (t)hread-partition for (s)mem to (r)egister copy (tSR_) - TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); - ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); - Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) - Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N) - - // Allocate C registers - // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type - // to eliminate some redundant pack+unpack instruction sequences for sub-word types - constexpr bool IsDirectS2R = cute::is_same_v> - && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; - using RegisterElementC = cute::conditional_t; - Tensor tRS_rC = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) - Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) - - // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) - ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); - Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) - Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - - // OOB predication for tile quantization "residue" - // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tRS_cD_mn = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (IsUseR2R) { - // (t)hread-partition for ConsumerStoreCallbacks. - TiledCopy tiled_cst = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); - ThrCopy thread_cst = tiled_cst.get_slice(thread_idx); - - return thread_cst.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) - } - else { - return thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) - } - }(); - // Relative coordinate tensors (static) - Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) - Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) - // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate - auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) - auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n) - - CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); - - if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { - // When the epilogue subtile is larger than the MMA tiles, loop over multiple MMA tiles - CUTE_STATIC_ASSERT(epi_tile_n % mma_tile_n == 0, "MMA_TILE_N must divide EPI_TILE_N"); - } - else { - CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); - } - - // Get TiledCopy for partition reference when consumer store. - TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); - // Get the fusion callbacks for the consumer store warps - constexpr bool RefSrc = true; // Register tensors reference tiled copy src layout - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs( - problem_shape_mnkl, - CtaTileMNK{}, - tile_coord_mnkl, - tiled_mma, - EpilogueTile{}, - tiled_copy_partition_ref, - cD, - residue_cD, - tRS_cD, - residue_tRS_cD, - tRS_rC, - thread_idx - ); - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg(0), 0, 0, 0)); - constexpr bool IsDirectR2S = cute::is_same_v>; - using RegisterElementD = cute::conditional_t; - Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) - Tensor tRS_rCompute_frg = recast>(tRS_rCompute); - - // Thread synchronizer for previously issued waits or fences - // to ensure visibility of smem reads/writes to threads or TMA unit - auto synchronize = [&] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // Predication for TMA store (one warp issues TMA store) - bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; - - // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. - // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can - // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. - // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. - LoadPipelineState load_wait_state = load_pipe_consumer_state; - if constexpr (ReuseSmemC) { - load_wait_state = store_pipe_producer_state; - load_wait_state.phase_ ^= 1; - } - - // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions - // Sync requirements of smem reuse may preclude this optimization - // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD - [[maybe_unused]] int epi_m_prev = 0; - [[maybe_unused]] int epi_n_prev = 0; - static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); - - // The TMA store sequence for one subtile iteration - auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { - // Write the tile from smem to gmem with TMA - cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - synchronize(); // ensure all threads have issued their async fence - if constexpr (is_destination_supported) { - if (issue_tma_store) { - copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); - } - } - - // Post async fence, pre TMA commit callback entry point - cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); - - // Commit the TMA stores for this stage - if (issue_tma_store) { - store_pipeline.producer_commit(store_pipe_producer_state); - } - ++store_pipe_producer_state; - ++issued_stores; - - // Wait for the next smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); - } - synchronize(); - - if constexpr (ReuseSmemC) { - // producer_acquire returns when at most StagesD-1 committed stores are pending - bool store_finished = issued_stores > StorePipeline::UnacquiredStages; - // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits - if (store_finished) { - if (is_producer_load_needed) { - load_pipeline.consumer_release(load_pipe_consumer_state); - } - ++load_pipe_consumer_state; - } - } - }; - - // - // BEGIN EPILOGUE - // - - // Pre-loop fusion callback entry point - cst_callbacks.begin(); - if (cst_callbacks.begin_sync_needed()) { - synchronize(); - } - - // For each output tile - CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { - CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { - [[maybe_unused]] bool is_first_iteration = epi_m == 0 && epi_n == 0; - bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; - - if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { - continue; - } - - cst_callbacks.begin_loop(epi_m, epi_n); - - if (is_producer_load_needed) { - // Wait for the producer load to fill smem - load_pipeline.consumer_wait(load_wait_state); - - if (is_C_load_needed) { - // Copy source tile from smem to register - copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); - // Ensure smem loads are complete before reusing smem for mixed types/layouts - if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { - synchronize(); - } - } - } - - // First loop fusion callback entry point - cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); - - if (is_producer_load_needed) { - if constexpr (not ReuseSmemC) { - // Let producer load warp know smem buffers are consumed and empty - cutlass::arch::fence_view_async_shared(); - load_pipeline.consumer_release(load_pipe_consumer_state); - ++load_pipe_consumer_state; - } - ++load_wait_state; - } - - if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { - // When the epilogue subtile is larger than the MMA tiles, loop over multiple - // MMA tiles - static constexpr int MmaMPerEpiM = epi_tile_m / mma_tile_m; - static constexpr int MmaNPerEpiN = epi_tile_n / mma_tile_n; - - CUTLASS_PRAGMA_UNROLL - for (int mma_n_in_epi = 0; mma_n_in_epi < MmaNPerEpiN; ++mma_n_in_epi) { - int mma_n = (epi_n * MmaNPerEpiN) + mma_n_in_epi; - - CUTLASS_PRAGMA_UNROLL - for (int mma_m_in_epi = 0; mma_m_in_epi < MmaMPerEpiM; ++mma_m_in_epi) { - int mma_m = (epi_m * MmaMPerEpiM) + mma_m_in_epi; - Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); - int idx_in_epi_subtile = (mma_n_in_epi * MmaMPerEpiM + mma_m_in_epi); - - tRS_rCompute_frg(idx_in_epi_subtile) = cst_callbacks.visit( - tRS_rAcc_frg_mn(0), idx_in_epi_subtile, epi_m, epi_n); - } - } - } - else { - int mma_m = epi_m; - int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; - Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); - - // Vectorized fragment loop with visitor callback entry point - int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); - int r2s_v = epi_n_in_mma * size(tRS_rCompute_frg); - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tRS_rCompute_frg); ++epi_v) { - tRS_rCompute_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); - } - } - - // The latest we can delay the TMA store is right before the smem store of the next iteration - // since the current TMA store needs to be committed before we can acquire the next smem buffer - if constexpr (DelayTmaStore) { - // Issue TMA stores for the previous subtile - if (not is_first_iteration and subtile_idx == -1) { - tma_store_fn(epi_m_prev, epi_n_prev); - } - epi_m_prev = epi_m; - epi_n_prev = epi_n; - } - - // Smem reduction callback entry point using current store buffer for workspace - cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), - synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); - - // Copy tile from register to regiser if needed - if constexpr (IsUseR2R) { - // retile source and destination for tiled_r2r - Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) - Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) - - // Output register transformation before copying to shared memory. - copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); - } - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tRS_rD_frg); ++i) { - tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); - } - - // Copy tile from register to smem - if constexpr (is_destination_supported) { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - } - - // Post reduction, pre TMA store callback entry point - constexpr bool issue_smem_store = true; // No smem store predication - cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); - - if constexpr (not DelayTmaStore) { - // Issue TMA stores for this subtile - tma_store_fn(epi_m, epi_n); - } - - cst_callbacks.end_loop(epi_m, epi_n); - - } // for epi_m - } // for epi_n - - if constexpr (DelayTmaStore) { - // Issue TMA stores for the last subtile - tma_store_fn(epi_m_prev, epi_n_prev); - } - - // Post-loop fusion callback entry point - cst_callbacks.end(); - - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); - } - - CUTLASS_DEVICE auto - store_tail( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state) { - // wait for all TMA stores to complete - store_pipeline.producer_tail(store_pipe_producer_state); - // reset store counter - issued_stores = 0; - - if constexpr (ReuseSmemC) { - if (fusion_callbacks.is_producer_load_needed()) { - // Issue releases on up to StagesD-1 previously issued TMA stores - constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{})); - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < release_stages; ++stage) { - load_pipeline.consumer_release(load_pipe_consumer_state); - ++load_pipe_consumer_state; - } - } - } - - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); - } - -private: - Params const& params; - FusionCallbacks fusion_callbacks; - int issued_stores = 0; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp deleted file mode 100644 index 2d5fd85827b2751085a78dcb241aa3cf081470d5..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp +++ /dev/null @@ -1,164 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing pipelined epilogues with bias add and elementwise activation functions. - This collective is now DEPRECATED, will be removed in the next release. Use EVT instead. -*/ - -#pragma once - -#include "sm90_epilogue_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int StagesC_, - int StagesD_, - int FragmentSize_, - class BlockTileShape_, // (BLK_M,BLK_N,BLK_K) - class EpilogueTileShape_, // (EPI_TILE_M,EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class FusionCallbacks_, - class CopyOpG2S_, - class SmemLayoutAtomC_, - class CopyOpS2R_, - class CopyOpS2G_, - class SmemLayoutAtomD_, - class CopyOpR2S_, - class CopyAtomC_, - class CopyOpR2R_ -> -class Sm90EpilogueTmaWarpSpecializedBiasElementwise - : public CollectiveEpilogue< - Sm90TmaWarpSpecialized, - BlockTileShape_, - EpilogueTileShape_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - FusionCallbacks_, - CopyOpG2S_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpS2G_, - SmemLayoutAtomD_, - CopyOpR2S_, - CopyAtomC_, - CopyOpR2R_ -> { -private: - using Impl = - CollectiveEpilogue< - Sm90TmaWarpSpecialized, - BlockTileShape_, - EpilogueTileShape_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - FusionCallbacks_, - CopyOpG2S_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpS2G_, - SmemLayoutAtomD_, - CopyOpR2S_, - CopyAtomC_, - CopyOpR2R_ - >; -public: - using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise; - using ElementCompute = typename Impl::ThreadEpilogueOp::ElementCompute; - using ElementBias = typename Impl::ThreadEpilogueOp::ElementBias; - using ElementT = typename Impl::ThreadEpilogueOp::ElementAux; - - // Constructor inheritance - using Impl::Impl; - - // Host side epilogue arguments - struct [[deprecated("use Sm90TmaWarpSpecialized Arguments instead")]] - Arguments { - struct ThreadArgs { - ElementCompute alpha{1}; - ElementCompute beta{0}; - ElementCompute const *alpha_ptr{nullptr}; - ElementCompute const *beta_ptr{nullptr}; - } thread; - ElementC_ const* ptr_C{nullptr}; - StrideC_ dC{}; - ElementD_* ptr_D{nullptr}; - StrideD_ dD{}; - ElementBias const* ptr_Bias{nullptr}; - ElementT* ptr_T{nullptr}; - - CUTLASS_HOST_DEVICE - operator typename Impl::Arguments() const { - typename Impl::Arguments arguments; - arguments.thread.alpha = thread.alpha; - arguments.thread.beta = thread.beta; - arguments.thread.alpha_ptr = thread.alpha_ptr; - arguments.thread.beta_ptr = thread.beta_ptr; - if constexpr (not cute::is_void_v) { - arguments.thread.bias_ptr = ptr_Bias; - } - if constexpr (not cute::is_void_v) { - arguments.thread.aux_ptr = ptr_T; - arguments.thread.dAux = dD; - } - arguments.ptr_C = ptr_C; - arguments.dC = dC; - arguments.ptr_D = ptr_D; - arguments.dD = dD; - - return arguments; - } - }; - -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/dispatch_policy.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/dispatch_policy.hpp deleted file mode 100644 index ca91ac19b0aadfeddcfb030ee16f03905855cd63..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/dispatch_policy.hpp +++ /dev/null @@ -1,302 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/scale_type.h" - -////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue { - -////////////////////////////////////////////////////////////////////////////// - -////////////////////////////////////////////////////////////////////////////// -// -// Builder Epilogue Schedules -// -////////////////////////////////////////////////////////////////////////////// -// Pre-Hopper schedules -struct PtrArrayDefault {}; -struct EpilogueSimtVectorized {}; -struct EpiloguePtrArraySimtVectorized {}; -// Hopper direct store schedules -struct NoSmemWarpSpecialized {}; -struct PtrArrayNoSmemWarpSpecialized {}; -struct PtrArrayNoSmemWarpSpecializedTransposed {}; -// Hopper TMA schedules -struct TmaWarpSpecialized {}; -struct TmaWarpSpecializedCooperative {}; -struct PtrArrayTmaWarpSpecialized { static constexpr int NumEpilogueWarpGroups = 1; }; -struct PtrArrayTmaWarpSpecializedPingpong { static constexpr int NumEpilogueWarpGroups = 2; }; -struct PtrArrayTmaWarpSpecializedCooperative { static constexpr int NumEpilogueWarpGroups = 2; }; -// Blackwell direct store schedules -struct NoSmemWarpSpecialized1Sm {}; -struct NoSmemWarpSpecialized2Sm {}; -struct FastF32NoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; -struct FastF32NoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; -struct BlockwiseNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; -struct BlockwiseNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; -struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; -struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; -struct PtrArrayFastF32NoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; -struct PtrArrayFastF32NoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; -struct PtrArrayBlockwiseNoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; -struct PtrArrayBlockwiseNoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; -// Blackwell TMA schedules -struct TmaWarpSpecialized1Sm {}; -struct TmaWarpSpecialized2Sm {}; -struct PtrArrayTmaWarpSpecialized1Sm : TmaWarpSpecialized1Sm {}; -struct PtrArrayTmaWarpSpecialized2Sm : TmaWarpSpecialized2Sm {}; -struct TmaWarpSpecialized1SmNvf4 final : TmaWarpSpecialized1Sm {}; -struct TmaWarpSpecialized2SmNvf4 final : TmaWarpSpecialized2Sm {}; -struct TmaWarpSpecialized1SmMxf4 final : TmaWarpSpecialized1Sm {}; -struct TmaWarpSpecialized2SmMxf4 final : TmaWarpSpecialized2Sm {}; -struct TmaWarpSpecialized1SmMxf8f6f4 final : TmaWarpSpecialized1Sm {}; -struct TmaWarpSpecialized2SmMxf8f6f4 final : TmaWarpSpecialized2Sm {}; -// Cooperative epilogue schedule for sm120 sparse kernels -struct SparseTmaWarpSpecializedCooperativeSm120 : public TmaWarpSpecializedCooperative {}; - -// DEPRECATED schedules, will be removed in next release -struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {}; -struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {}; -template < - template class ActivationFunctor_, - thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, - FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest -> -struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] -TmaWarpSpecializedElementwise : public TmaWarpSpecializedElementwiseBase { - template - using ActivationFunctor = ActivationFunctor_; - static constexpr thread::ScaleType::Kind Scale = Scale_; - static constexpr FloatRoundStyle Round = Round_; -}; - -template < - template class ActivationFunctor_, - thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, - FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest -> -struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombEltAct instead")]] -TmaWarpSpecializedCooperativeElementwise : public TmaWarpSpecializedCooperativeElementwiseBase { - template - using ActivationFunctor = ActivationFunctor_; - static constexpr thread::ScaleType::Kind Scale = Scale_; - static constexpr FloatRoundStyle Round = Round_; -}; - -struct TmaWarpSpecializedBiasElementwiseBase : public TmaWarpSpecialized{}; -struct TmaWarpSpecializedCooperativeBiasElementwiseBase : public TmaWarpSpecializedCooperative {}; - -template < - template class ActivationFunctor_, - class ElementT_, - template class BiasOp_, - bool StoreT_, - class ElementBias_ -> -struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltActAux instead")]] -TmaWarpSpecializedBiasElementwise : public TmaWarpSpecializedBiasElementwiseBase { - template - using ActivationFunctor = ActivationFunctor_; - using ElementT = ElementT_; - - template - using BiasOp = BiasOp_; - - static constexpr bool StoreT = StoreT_; - using ElementBias = ElementBias_; -}; - -template < - template class ActivationFunctor_, - class ElementT_, - template class BiasOp_, - bool StoreT_, - class ElementBias_ -> -struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombPerRowBiasEltActAux instead")]] -TmaWarpSpecializedCooperativeBiasElementwise : public TmaWarpSpecializedCooperativeBiasElementwiseBase { - template - using ActivationFunctor = ActivationFunctor_; - - using ElementT = ElementT_; - - template - using BiasOp = BiasOp_; - - static constexpr bool StoreT = StoreT_; - using ElementBias = ElementBias_; -}; - -////////////////////////////////////////////////////////////////////////////// -// -// Collective Dispatch Policies -// -////////////////////////////////////////////////////////////////////////////// - -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_ -> -struct Sm90TmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; -}; - -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_, - int NumEpilogueWarpGroups_ -> -struct Sm90PtrArrayTmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; - constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; -}; - -// DEPRECATED policies, will be removed in next release -template< - int StagesC_, - int StagesD_, - int FragmentSize_ = 2 -> -struct Sm90TmaWarpSpecializedBiasElementwise { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; -}; - - -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_ -> -struct Sm100TmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; -}; - -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_ -> -struct Sm100PtrArrayTmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; - - static_assert(StagesC >= 1, "StagesC must be >= 1"); - static_assert(StagesD >= 1, "StagesD must be >= 1"); -}; - -struct Sm100NoSmem { - constexpr static int StagesC = 1; - constexpr static int StagesD = 1; - constexpr static int FragmentSize = 1; -}; - -struct Sm100NoSmemWarpSpecialized { - constexpr static int StagesC = 1; - constexpr static int StagesD = 1; - constexpr static int FragmentSize = 1; -}; - -struct Sm100PtrArrayNoSmem { - constexpr static int StagesC = 1; - constexpr static int StagesD = 1; - constexpr static int FragmentSize = 1; -}; - -struct Sm100PtrArrayNoSmemWarpSpecialized { - constexpr static int StagesC = 1; - constexpr static int StagesD = 1; - constexpr static int FragmentSize = 1; -}; -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_ -> -struct Sm120TmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; -}; - -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_, - int NumEpilogueWarpGroups_ -> -struct Sm120PtrArrayTmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; - constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; -}; - -////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp deleted file mode 100644 index f9febeec4d92d54ec02e221d028f7329c2edeea5..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp +++ /dev/null @@ -1,91 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/detail/dependent_false.hpp" -#include "cutlass/epilogue/fusion/operations.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Dispatch interface for epilogue fusion callbacks -// For visitor fusions, this is just a convenience wrapper to provide metadata and non-nested args. -// It is also valid to just pass visitor callbacks directly to the collective, e.g. fusion::Sm90LinearCombination, -// provided the collective supports a visitor callbacks interface. This is useful for implementing custom fusions. -template < - class DispatchPolicy, // specialize on collective's dispatch policy since callbacks API will depend on collective's algorithm - class Operation, // the fusion operation being performed, e.g. fusion::LinearCombination - class CtaTile_MNK, // computed tile per CTA - class EpilogueTile_MN, // epilogue subtile size - class... Args // callbacks implementation dependent args (e.g. copy atoms, smem layouts) -> -struct FusionCallbacks { - static_assert(cutlass::detail::dependent_false, "Could not find a callbacks specialization."); -}; - -// Metadata helper to handle custom EVTs or other non-FusionCallbacks types -template -struct FusionCallbacksTraits { - using DispatchPolicy = void; - using Callbacks = T; - using Operation = FusionOperation; - using CtaTile_MNK = void; - using EpilogueTile_MN = void; - using ElementCompute = void; -}; - -template < - class DispatchPolicy_, - class Operation_, - class CtaTile_MNK_, - class EpilogueTile_MN_, - class... Args -> -struct FusionCallbacksTraits< - FusionCallbacks -> { - using DispatchPolicy = DispatchPolicy_; - using Callbacks = FusionCallbacks; - using Operation = Operation_; - using CtaTile_MNK = CtaTile_MNK_; - using EpilogueTile_MN = EpilogueTile_MN_; - using ElementCompute = typename Operation::ElementCompute; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/operations.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/operations.hpp deleted file mode 100644 index 114737a9d910a458f4895212d0904e002a9aeec8..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/operations.hpp +++ /dev/null @@ -1,645 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include -#include -#include -#include // cute::false_type - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Fusion Operations -// Template args must not be implementation dependent -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -struct FusionOperation { - // metadata types/queries that can be overrided - using ElementOutput = void; - using ElementCompute = void; - FloatRoundStyle RoundStyle = FloatRoundStyle::round_indeterminate; - - using ElementSource = void; - static constexpr bool IsSourceSupported = false; - static constexpr bool IsResidualSupported = false; // Source is added after activation - - using ElementScalar = void; - static constexpr int AlignmentScalar = 0; - static constexpr bool IsScaleFactorSupported = false; - static constexpr bool IsPerRowScaleSupported = false; - static constexpr bool IsPerColScaleSupported = false; - - using ElementBias = void; - static constexpr int AlignmentBias = 0; - static constexpr bool IsPerRowBiasSupported = false; - static constexpr bool IsPerColBiasSupported = false; - static constexpr bool IsDePerRowBiasSupported = false; - - using ActivationFn = void; - static constexpr bool IsEltActSupported = false; - static constexpr bool IsDeEltActSupported = false; - - using ElementAux = void; - using GmemLayoutTagAux = void; - static constexpr int AlignmentAux = 0; - static constexpr bool IsAuxOutSupported = false; - static constexpr bool IsAuxInSupported = false; - - using ElementAmax = void; - static constexpr bool IsAbsMaxSupported = false; - - using ElementBlockScaleFactor = void; - static constexpr int SFVecSize = 0; - static constexpr bool IsBlockScaleSupported = false; // Umbrella variable to check BlockScaling support in the epilogues - using GmemLayoutTagScalefactor = void; -}; - -// D = alpha * acc -template< - class ElementOutput_, - class ElementCompute_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct ScaledAcc : FusionOperation { - using ElementOutput = ElementOutput_; - using ElementCompute = ElementCompute_; - using ElementScalar = ElementScalar_; - static constexpr int AlignmentScalar = 1; - static constexpr auto RoundStyle = RoundStyle_; -}; - -// D = alpha * acc + beta * C -template< - class ElementOutput_, - class ElementCompute_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinearCombination - : ScaledAcc { - using ElementSource = ElementSource_; - static constexpr bool IsSourceSupported = true; -}; - -// D = activation(alpha * acc + beta * C) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombEltAct - : LinearCombination { - using ActivationFn = ActivationFn_; - static constexpr bool IsEltActSupported = true; -}; - -// D = softmax(top_k(alpha * acc + beta * C)) -template< - int TopK, - class ElementOutput_, - class ElementCompute_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombTopKSoftmaxCol - : LinearCombination { -}; - - -// D = alpha * acc + beta * C + per-row bias -template< - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerRowBias - : LinearCombination { - using ElementBias = ElementBias_; - static constexpr int AlignmentBias = AlignmentBias_; - static constexpr bool IsPerRowBiasSupported = true; -}; - -// D = alpha * acc + beta * C + per-column bias -template< - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerColBias - : LinearCombination { - using ElementBias = ElementBias_; - static constexpr int AlignmentBias = AlignmentBias_; - static constexpr bool IsPerColBiasSupported = true; -}; - -// D = activation(alpha * acc + beta * C + per-row bias) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerRowBiasEltAct - : LinCombPerRowBias { - using ActivationFn = ActivationFn_; - static constexpr bool IsEltActSupported = true; -}; - -// Grouped Wgrad's D = alpha * acc + beta * C with special AccFetch. -template< - class GroupsPerTile_, - class ElementOutput_, - class ElementCompute_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinearCombinationGroupedWgrad - : LinearCombination { - using GroupsPerTile = GroupsPerTile_; -}; - -// D = activation(alpha * acc + beta * C + per-column bias) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerColBiasEltAct - : LinCombPerColBias { - using ActivationFn = ActivationFn_; - static constexpr bool IsEltActSupported = true; -}; - -// D = activation(alpha * acc + beta * C + per-row bias) -// aux = alpha * acc + beta * C + per-row bias -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerRowBiasEltActAux - : LinCombPerRowBiasEltAct { - using ElementAux = ElementAux_; - using GmemLayoutTagAux = GmemLayoutTagAux_; - static constexpr int AlignmentAux = AlignmentAux_; - static constexpr bool IsAuxOutSupported = true; -}; - -// D = activation(alpha * acc + beta * C + per-col bias) -// aux = alpha * acc + beta * C + per-col bias -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerColBiasEltActAux - : LinCombPerColBiasEltAct { - using ElementAux = ElementAux_; - using GmemLayoutTagAux = GmemLayoutTagAux_; - static constexpr int AlignmentAux = AlignmentAux_; - static constexpr bool IsAuxOutSupported = true; -}; - -// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, // per-row alpha/beta - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - int AlignmentScalar_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct PerRowLinCombPerRowBiasEltAct - : LinCombPerRowBiasEltAct { - static constexpr int AlignmentScalar = AlignmentScalar_; - static constexpr bool IsPerRowScaleSupported = true; -}; - -// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, // per-row alpha/beta - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - int AlignmentScalar_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct PerColLinCombPerColBiasEltAct - : LinCombPerColBiasEltAct { - static constexpr int AlignmentScalar = AlignmentScalar_; - static constexpr bool IsPerColScaleSupported = true; -}; - -// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, // per-row alpha/beta - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - int AlignmentScalar_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct PerColResAddPerColBiasEltAct - : PerColLinCombPerColBiasEltAct { - static constexpr bool IsResidualSupported = true; -}; - -// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias -// if D is fp8 -// D = scale_d * activation(Z) -// else -// D = activation(Z) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct ScaledLinCombPerRowBiasEltAct - : LinCombPerRowBiasEltAct { - static constexpr bool IsScaleFactorSupported = true; -}; - -// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias -// if D is fp8 -// D = scale_d * activation(Z) -// else -// D = activation(Z) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct ScaledLinCombPerColBiasEltAct - : LinCombPerColBiasEltAct { - static constexpr bool IsScaleFactorSupported = true; -}; - -// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias -// if D is fp8 -// amax_d = max(abs(elements in activation(Z))) -// D = scale_d * activation(Z) -// else -// D = activation(Z) -// if Aux is fp8 -// amax_aux = max(abs(elements in Z)) -// Aux = scale_aux * Z -// else -// Aux = Z -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementAmax_ = ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct ScaledLinCombPerRowBiasEltActAmaxAux - : ScaledLinCombPerRowBiasEltAct { - using ElementAmax = ElementAmax_; - static constexpr bool IsAbsMaxSupported = true; - - using ElementAux = ElementAux_; - using GmemLayoutTagAux = GmemLayoutTagAux_; - static constexpr int AlignmentAux = AlignmentAux_; - static constexpr bool IsAuxOutSupported = true; -}; - -// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias -// if D is fp8 -// amax_d = max(abs(elements in activation(Z))) -// D = scale_d * activation(Z) -// else -// D = activation(Z) -// if Aux is fp8 -// amax_aux = max(abs(elements in Z)) -// Aux = scale_aux * Z -// else -// Aux = Z -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementAmax_ = ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct ScaledLinCombPerColBiasEltActAmaxAux - : ScaledLinCombPerColBiasEltAct { - using ElementAmax = ElementAmax_; - static constexpr bool IsAbsMaxSupported = true; - - using ElementAux = ElementAux_; - using GmemLayoutTagAux = GmemLayoutTagAux_; - static constexpr int AlignmentAux = AlignmentAux_; - static constexpr bool IsAuxOutSupported = true; -}; - -// Z = Aux -// dY = alpha * acc + beta * C -// D = d_activation(dY, Z) -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombDeEltAct - : LinearCombination { - using ActivationFn = ActivationFn_; - static constexpr bool IsDeEltActSupported = true; - - using ElementAux = ElementAux_; - using GmemLayoutTagAux = GmemLayoutTagAux_; - static constexpr int AlignmentAux = AlignmentAux_; - static constexpr bool IsAuxInSupported = true; -}; - -// Z = Aux -// dY = alpha * acc + beta * C -// D = d_activation(dY, Z) -// dBias = sum of columns of D -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementBias_ = ElementCompute_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombDeEltActDePerRowBias - : LinCombDeEltAct { - using ElementBias = ElementBias_; - static constexpr int AlignmentBias = AlignmentBias_; - static constexpr bool IsDePerRowBiasSupported = true; -}; - -template< - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombBlockScaleFactor - : LinearCombination { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - -// D = activation(alpha * acc + beta * C) -// With BlockScaleFactor generation (same recipe as LinCombBlockScaleFactor). -template< - template class ActivationFn_, - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombEltActBlockScaleFactor - : LinCombEltAct { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - -// D = alpha * acc + beta * C + per-row bias -// With BlockScaleFactor generation -template< - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerRowBiasBlockScaleFactor - : LinCombPerRowBias { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - - -// D = alpha * acc + beta * C + per-col bias -// With BlockScaleFactor generation. -template< - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerColBiasBlockScaleFactor - : LinCombPerColBias { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - - -// D = activation(alpha * acc + beta * C + per-row bias) -// With BlockScaleFactor generation. -template< - template class ActivationFn_, - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerRowBiasEltActBlockScaleFactor - : LinCombPerRowBiasEltAct { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - - -// D = activation(alpha * acc + beta * C + per-col bias) -// With BlockScaleFactor generation. -template< - template class ActivationFn_, - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerColBiasEltActBlockScaleFactor - : LinCombPerColBiasEltAct { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp deleted file mode 100644 index dfbb75bf00bd2160af770566c4f3970a2c7b5b10..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp +++ /dev/null @@ -1,1322 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Fusion callbacks specializations for the sm100 TMA warp-specialized (ws) epilogue -*/ - - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/fusion/callbacks.hpp" -#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" - -#include "cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Sm100 Tma warp specialized callbacks just alias to their sm90 counterpart -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... -> : FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... - > { - using FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>::FusionCallbacks; -}; - -// Sm100 direct store callbacks alias to sm100 tma callbacks with 0 stages -// Additional copy atom args will be ignored in the 0-stage specializations of aux load/store nodes -template < - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm100NoSmemWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... -> : FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized<0, 0, 0, false, false>, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... - > { - using FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized<0, 0, 0, false, false>, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>::FusionCallbacks; -}; - -// Sm100 Ptr array tma warp specialized callbacks just alias to their sm90 counterpart -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm100PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... -> : FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... - > { - using FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>::FusionCallbacks; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C -// With Row BlockScaleFactor Generation. -template< - int SFVecsize, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinearCombRowBlockScaleFactor = - Sm90EVT, // gen scalefactor - Sm90LinearCombination // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinearCombRowBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm100LinearCombRowBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombBlockScaleFactor; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { - { - // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = alpha * acc + beta * C -// With Col BlockScaleFactor Generation. -template< - int SFVecsize, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinearCombColBlockScaleFactor = - Sm90EVT, // gen scalefactor - Sm90LinearCombination // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinearCombColBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm100LinearCombColBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombBlockScaleFactor; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { - { - // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// For Ptr-Array and Grouped GEMM -// D = alpha * acc + beta * C, where alpha and beta can be vectors for each batch/group -// With Row BlockScaleFactor Generation, separate tensors per batch/group. -template< - int SFVecsize, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinearCombRowBlockScaleFactorPtrArray = - Sm90EVT, // gen scalefactor - Sm90LinearCombinationPtrArray // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100PtrArrayTmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinearCombRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm100LinearCombRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombBlockScaleFactor; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - // NormConst is a single device-side constant value, its not per-batch or per-group - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { - { - // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// For Ptr-Array and Grouped GEMM -// D = activation(alpha * acc + beta * C), where alpha and beta can be vectors for each batch/group -// With Row BlockScaleFactor Generation, separate tensors per batch/group. -template< - int SFVecsize, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombEltActRowBlockScaleFactorPtrArray = - Sm90EVT, // gen scalefactor - Sm90LinCombEltActPtrArray // activation(beta * C + (alpha * acc)) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100PtrArrayTmaWarpSpecialized, - fusion::LinCombEltActBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombEltActRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm100LinCombEltActRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombEltActBlockScaleFactor; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op: activation(beta * C + (alpha * acc)) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C + per-row bias -// with row blockScaled generation -template< - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerRowBiasRowBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerRowBiasRowBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, - ElementScalar, - AlignmentBias, - RoundStyle - > -{ - - using Impl = - Sm100LinCombPerRowBiasRowBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// D = alpha * acc + beta * C + per-row bias -// with col blockScaled generation -template< - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerRowBiasColBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorColStore< - SFVecsize, EpilogueTile, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerRowBiasColBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm100LinCombPerRowBiasColBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C + per_col bias -// with row blockScaled generation -template< - int StagesC, - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerColBiasRowBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerColBias< - StagesC, CtaTileShapeMNK, EpilogueTile, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerColBiasRowBlockScaleFactor< - StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm100LinCombPerColBiasRowBlockScaleFactor< - StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per-row bias) -// with row blockScaled generation -template< - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerRowBiasEltActRowBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, - ElementOutput, ElementCompute, - ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerRowBiasEltActRowBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm100LinCombPerRowBiasEltActRowBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per-row bias) -// with col blockScaled generation -template< - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerRowBiasEltActColBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorColStore< - SFVecsize, EpilogueTile, - ElementOutput, ElementCompute, - ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerRowBiasEltActColBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm100LinCombPerRowBiasEltActColBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per_col bias) -// with row blockScaled generation -template< - int StagesC, - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerColBiasEltActRowBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, - ElementOutput, ElementCompute, - ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerColBiasEltActRowBlockScaleFactor< - StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm100LinCombPerColBiasEltActRowBlockScaleFactor< - StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - - -// -------------------------------------------------------------------- -// Sm100PtrArrayNoSmemWarpSpecialized (direct-store, grouped GEMM) -// -------------------------------------------------------------------- -template < - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm100PtrArrayNoSmemWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...> - : FusionCallbacks< - // reuse the ptr-array *TMA* callbacks with 0 stages - epilogue::Sm100PtrArrayTmaWarpSpecialized<0,0,0,false,false>, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...> { - - using Base = FusionCallbacks< - epilogue::Sm100PtrArrayTmaWarpSpecialized<0,0,0,false,false>, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>; - - // bring ctors into scope - using Base::Base; -}; - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp deleted file mode 100644 index a20591288ad386543c3c7f0fd399c7fe45b7f60a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp +++ /dev/null @@ -1,500 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree compute operations for the sm100 TMA warp-specialized (ws) epilogue -*/ - - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/detail/sm100_blockscaled_layout.hpp" -#include "cutlass/epilogue/thread/activation.h" -#include "cute/tensor.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// BatchNormApply -// -// This node aims to do the batch norm apply. The procedure is described as follows: -// -// output = (input - mean) * inv_stddev * alpha + bias -// -// while: (1) input & output are 2 matrices with shape (M, N), -// which are frg_input & return value of the visit function -// -// (2) mean, inv_stddev, alpha & bias are 4 vectors with shape (N). -// which are loaded by ProducerLoadCallbacks -// -// To avoid redundant calculations in EVT, this node simplify the procedure as follows: -// -// output = input * alpha' + bias' -// -// while alpha' & bias' are 2 vectors with shape (N) calculated by mean, inv_stddev, alpha & bias -// -// The calculation among vectors is described as follows: -// -// alpha' = alpha * inv_stddev -// bias' = bias - mean * alpha' -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - // reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least - // this should just match CLC stage count - int Stages, - class CtaTileShapeMNK, - class ElementScalar, - class ElementCompute, - class ElementOutput, - class StrideMNL = Stride<_0,_1,_0>, - int Alignment = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -struct Sm100BatchNormApply { - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert(cute::is_same_v>); // row vector broadcast for alpha, bias, mean & inv_stddev - - using SmemLayout = decltype(make_layout(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})))); - - using ElementCol = cute::conditional_t<(sizeof(ElementCompute) > sizeof(ElementScalar)), ElementCompute, ElementScalar>; - - struct SharedStorage { - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_alpha; - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_bias; - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_mean; - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_inv_stddev; - }; - - struct Arguments { - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* bias_ptr = nullptr; - ElementScalar const* mean_ptr = nullptr; - ElementScalar const* inv_stddev_ptr = nullptr; - StrideMNL dVec = {}; - }; - - struct Params { - using TMA_Vec = decltype(make_tma_atom( - SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr(nullptr), repeat_like(StrideMNL{}, int32_t(0)), append<3>(StrideMNL{}, _0{})), - take<0,2>(SmemLayout{}), - take<0,2>(CtaTileShapeMNK{}))); - - TMA_Vec tma_load_alpha; - TMA_Vec tma_load_bias; - TMA_Vec tma_load_mean; - TMA_Vec tma_load_inv_stddev; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - - Tensor tensor_alpha = make_tensor(make_gmem_ptr(args.alpha_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); - Tensor tensor_bias = make_tensor(make_gmem_ptr(args.bias_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); - Tensor tensor_mean = make_tensor(make_gmem_ptr(args.mean_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); - Tensor tensor_inv_stddev = make_tensor(make_gmem_ptr(args.inv_stddev_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); - - typename Params::TMA_Vec tma_load_alpha = make_tma_atom(SM90_TMA_LOAD{}, tensor_alpha, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); - typename Params::TMA_Vec tma_load_bias = make_tma_atom(SM90_TMA_LOAD{}, tensor_bias, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); - typename Params::TMA_Vec tma_load_mean = make_tma_atom(SM90_TMA_LOAD{}, tensor_mean, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); - typename Params::TMA_Vec tma_load_inv_stddev = make_tma_atom(SM90_TMA_LOAD{}, tensor_inv_stddev, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); - - return Params{tma_load_alpha, tma_load_bias, tma_load_mean, tma_load_inv_stddev}; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm100BatchNormApply() { } - - CUTLASS_HOST_DEVICE - Sm100BatchNormApply(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms), - smem_alpha(const_cast(shared_storage.smem_alpha.data())), - smem_bias(const_cast(shared_storage.smem_bias.data())), - smem_mean(const_cast(shared_storage.smem_mean.data())), - smem_inv_stddev(const_cast(shared_storage.smem_inv_stddev.data())), - smem_col_alpha(const_cast(shared_storage.smem_alpha.data())), - smem_col_bias(const_cast(shared_storage.smem_bias.data())) { } - - Params const* params_ptr; - ElementScalar* smem_alpha; - ElementScalar* smem_bias; - ElementScalar* smem_mean; - ElementScalar* smem_inv_stddev; - ElementCompute* smem_col_alpha; - ElementCompute* smem_col_bias; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return true; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { - CUTLASS_DEVICE - ProducerLoadCallbacks(GTensor&& gAlpha, GTensor&& gBias, GTensor&& gMean, GTensor&& gInvStddev, - STensor&& sAlpha, STensor&& sBias, STensor&& sMean, STensor&& sInvStddev, Params const* params_ptr) - : gAlpha(cute::forward(gAlpha)), - gBias(cute::forward(gBias)), - gMean(cute::forward(gMean)), - gInvStddev(cute::forward(gInvStddev)), - sAlpha(cute::forward(sAlpha)), - sBias(cute::forward(sBias)), - sMean(cute::forward(sMean)), - sInvStddev(cute::forward(sInvStddev)), - params_ptr(params_ptr) {} - - GTensor gAlpha; - GTensor gBias; - GTensor gMean; - GTensor gInvStddev; - - STensor sAlpha; - STensor sBias; - STensor sMean; - STensor sInvStddev; - - Params const* params_ptr; - - CUTLASS_DEVICE void - step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { - if (epi_m == 0 && epi_n == 0 && issue_tma_load) { - // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size - constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * bits_to_bytes(sizeof_bits_v) * 4; - cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); - // Issue the TMA bulk copy - int pipe_index = (load_iteration / EpiTiles) % Stages; - copy(params_ptr->tma_load_alpha.with(*full_mbarrier_ptr), gAlpha, sAlpha(_,pipe_index)); - copy(params_ptr->tma_load_bias.with(*full_mbarrier_ptr), gBias, sBias(_,pipe_index)); - copy(params_ptr->tma_load_mean.with(*full_mbarrier_ptr), gMean, sMean(_,pipe_index)); - copy(params_ptr->tma_load_inv_stddev.with(*full_mbarrier_ptr), gInvStddev, sInvStddev(_,pipe_index)); - } - } - }; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - - Tensor mAlpha = params_ptr->tma_load_alpha.get_tma_tensor(make_shape(size(M),N,size(L))); - Tensor mBias = params_ptr->tma_load_bias.get_tma_tensor(make_shape(size(M),N,size(L))); - Tensor mMean = params_ptr->tma_load_mean.get_tma_tensor(make_shape(size(M),N,size(L))); - Tensor mInvStddev = params_ptr->tma_load_inv_stddev.get_tma_tensor(make_shape(size(M),N,size(L))); - - Tensor gAlpha = local_tile(mAlpha, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor gBias = local_tile(mBias, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor gMean = local_tile(mMean, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor gInvStddev = local_tile(mInvStddev, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - - Tensor sAlpha = make_tensor(make_smem_ptr(smem_alpha), SmemLayout{}); // (CTA_M,CTA_N,PIPE) - Tensor sBias = make_tensor(make_smem_ptr(smem_bias), SmemLayout{}); // (CTA_M,CTA_N,PIPE) - Tensor sMean = make_tensor(make_smem_ptr(smem_mean), SmemLayout{}); // (CTA_M,CTA_N,PIPE) - Tensor sInvStddev = make_tensor(make_smem_ptr(smem_inv_stddev), SmemLayout{}); // (CTA_M,CTA_N,PIPE) - - auto [tCgAlpha, tCsAlpha] = tma_partition(params_ptr->tma_load_alpha, group_modes<0,2>(sAlpha), group_modes<0,2>(gAlpha)); - auto [tCgBias, tCsBias] = tma_partition(params_ptr->tma_load_bias, group_modes<0,2>(sBias), group_modes<0,2>(gBias)); - auto [tCgMean, tCsMean] = tma_partition(params_ptr->tma_load_mean, group_modes<0,2>(sMean), group_modes<0,2>(gMean)); - auto [tCgInvStddev, tCsInvStddev] = tma_partition(params_ptr->tma_load_inv_stddev, group_modes<0,2>(sInvStddev), group_modes<0,2>(gInvStddev)); - - constexpr int EpiTiles = decltype(size(ceil_div(shape(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ProducerLoadCallbacks( - cute::move(tCgAlpha), cute::move(tCgBias), cute::move(tCgMean), cute::move(tCgInvStddev), - cute::move(tCsAlpha), cute::move(tCsBias), cute::move(tCsMean), cute::move(tCsInvStddev), params_ptr); - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - SR_RTensor&& tSR_rAlpha, SR_RTensor&& tSR_rBias, - SR_RTensor&& tSR_rMean, SR_RTensor&& tSR_rInvStddev, - SR_STensor&& tSR_sAlpha, SR_STensor&& tSR_sBias, - SR_STensor&& tSR_sMean, SR_STensor&& tSR_sInvStddev, - SR_CTensor&& tSR_cAlpha, - SR_SCTensor&& tSR_sColAlpha, SR_SCTensor&& tSR_sColBias, - RTensor&& tCrAlpha, RTensor&& tCrBias, - STensor&& tCsAlpha, STensor&& tCsBias, - ThrNum thr_num, - Params const* params_ptr) - : - tSR_rAlpha(cute::forward(tSR_rAlpha)), tSR_rBias(cute::forward(tSR_rBias)), - tSR_rMean(cute::forward(tSR_rMean)), tSR_rInvStddev(cute::forward(tSR_rInvStddev)), - tSR_sAlpha(cute::forward(tSR_sAlpha)), tSR_sBias(cute::forward(tSR_sBias)), - tSR_sMean(cute::forward(tSR_sMean)), tSR_sInvStddev(cute::forward(tSR_sInvStddev)), - tSR_cAlpha(cute::forward(tSR_cAlpha)), - tSR_sColAlpha(cute::forward(tSR_sColAlpha)), tSR_sColBias(cute::forward(tSR_sColBias)), - tCrAlpha(cute::forward(tCrAlpha)), tCrBias(cute::forward(tCrBias)), - tCsAlpha(cute::forward(tCsAlpha)), tCsBias(cute::forward(tCsBias)), - thr_num(thr_num), - params_ptr(params_ptr) {} - - SR_RTensor tSR_rAlpha; - SR_RTensor tSR_rBias; - SR_RTensor tSR_rMean; - SR_RTensor tSR_rInvStddev; - SR_STensor tSR_sAlpha; - SR_STensor tSR_sBias; - SR_STensor tSR_sMean; - SR_STensor tSR_sInvStddev; - SR_CTensor tSR_cAlpha; - SR_SCTensor tSR_sColAlpha; - SR_SCTensor tSR_sColBias; - - ThrNum thr_num; - - RTensor tCrAlpha; // (CPY,CPY_M,CPY_N) - RTensor tCrBias; // (CPY,CPY_M,CPY_N) - - STensor tCsAlpha; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - STensor tCsBias; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - - Params const* params_ptr; - - CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { - if (epi_m == 0 && epi_n == 0) { // Assumes M-major subtile loop - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - int pipe_index = (load_iteration / EpiTiles) % Stages; - - Tensor tSR_rAlpha_flt = filter_zeros(tSR_rAlpha); - Tensor tSR_rBias_flt = filter_zeros(tSR_rBias); - Tensor tSR_rMean_flt = filter_zeros(tSR_rMean); - Tensor tSR_rInvStddev_flt = filter_zeros(tSR_rInvStddev); - Tensor tSR_sAlpha_flt = filter_zeros(tSR_sAlpha(_,_,_,pipe_index)); - Tensor tSR_sBias_flt = filter_zeros(tSR_sBias(_,_,_,pipe_index)); - Tensor tSR_sMean_flt = filter_zeros(tSR_sMean(_,_,_,pipe_index)); - Tensor tSR_sInvStddev_flt = filter_zeros(tSR_sInvStddev(_,_,_,pipe_index)); - Tensor tSR_cAlpha_flt = filter_zeros(tSR_cAlpha, tSR_rAlpha.stride()); - - for (int i = 0; i < size(tSR_rAlpha_flt); ++i) { - if (get<1>(tSR_cAlpha_flt(i)) >= size<1>(CtaTileShapeMNK{})) { - // OOB of SMEM - continue; - } - tSR_rAlpha_flt(i) = tSR_sAlpha_flt(i); - tSR_rBias_flt(i) = tSR_sBias_flt(i); - tSR_rMean_flt(i) = tSR_sMean_flt(i); - tSR_rInvStddev_flt(i) = tSR_sInvStddev_flt(i); - } - - constexpr int RegFragSize = cute::min(size(tSR_rAlpha_flt), cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute)))); - Tensor tSR_rAlpha_frg = recast>(tSR_rAlpha_flt); // (FRG_V) - Tensor tSR_rBias_frg = recast>(tSR_rBias_flt); // (FRG_V) - Tensor tSR_rMean_frg = recast>(tSR_rMean_flt); // (FRG_V) - Tensor tSR_rInvStddev_frg = recast>(tSR_rInvStddev_flt); // (FRG_V) - - cutlass::multiplies> mul; - cutlass::negate> negate; - cutlass::multiply_add> mul_add; - - // We do computation among vectors before computation among matrices - // alpha' = alpha * inv_stddev - // bias' = bias - alpha' * mean - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tSR_rAlpha_frg); ++i) { - tSR_rAlpha_frg(i) = mul(tSR_rAlpha_frg(i), tSR_rInvStddev_frg(i)); - tSR_rBias_frg(i) = mul_add(tSR_rAlpha_frg(i), negate(tSR_rMean_frg(i)), tSR_rBias_frg(i)); - } - - Tensor tSR_sColAlpha_flt = filter_zeros(tSR_sColAlpha(_,_,_,pipe_index)); - Tensor tSR_sColBias_flt = filter_zeros(tSR_sColBias(_,_,_,pipe_index)); - // After computation, 4 vectors -> 2 vectors - for (int i = 0; i < size(tSR_rAlpha_flt); ++i) { - if (get<1>(tSR_cAlpha_flt(i)) >= size<1>(CtaTileShapeMNK{})) { - // OOB of SMEM - continue; - } - tSR_sColAlpha_flt(i) = tSR_rAlpha_flt(i); - tSR_sColBias_flt(i) = tSR_rBias_flt(i); - } - - synchronize(); - - // To do bn_apply with Acc, reload these 2 vectors with the consistent shape - copy_aligned(tCsAlpha(_,_,_,_,_,pipe_index), tCrAlpha); - copy_aligned(tCsBias(_,_,_,_,_,pipe_index), tCrBias); - } - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_inputs) { - constexpr int RegFragSize = cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute))); - cutlass::multiply_add> mul_add; - - Array frg_apply; - - using ConvertInput = NumericArrayConverter; - using ConvertOutput = NumericArrayConverter; - - ConvertInput convert_input{}; - ConvertOutput convert_output{}; - - Array frg_I = convert_input(frg_inputs); - - Tensor tCrAlpha_frg = recast>(tCrAlpha(_,_,_,epi_m,epi_n)); - Tensor tCrBias_frg = recast>(tCrBias(_,_,_,epi_m,epi_n)); - - constexpr int RegFragArraySize = FragmentSize / RegFragSize; - using RegFragArr = Array, RegFragArraySize>; - RegFragArr& frg_I_ = reinterpret_cast(frg_I); - RegFragArr& frg_apply_ = reinterpret_cast(frg_apply); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < RegFragArraySize; ++i) { - frg_apply_[i] = mul_add(tCrAlpha_frg(epi_v * RegFragArraySize + i), frg_I_[i], tCrBias_frg(epi_v * RegFragArraySize + i)); - } - - return convert_output(frg_apply); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - using ThreadCount = decltype(size(args.tiled_copy)); - - Tensor sAlpha = make_tensor(make_smem_ptr(smem_alpha), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor sBias = make_tensor(make_smem_ptr(smem_bias), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor sColAlpha = make_tensor(make_smem_ptr(smem_col_alpha), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor sColBias = make_tensor(make_smem_ptr(smem_col_bias), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor sMean = make_tensor(make_smem_ptr(smem_mean), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor sInvStddev = make_tensor(make_smem_ptr(smem_inv_stddev), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - - // S2R: Smem to Reg - auto tiled_s2r = make_tiled_copy(Copy_Atom{}, - Layout< Shape<_1, ThreadCount>, - Stride<_0, _1>>{}, - Layout<_1>{}); - auto thr_s2r = tiled_s2r.get_slice(args.thread_idx); - Tensor tSR_sAlpha = thr_s2r.partition_S(sAlpha); - Tensor tSR_sBias = thr_s2r.partition_S(sBias); - Tensor tSR_sMean = thr_s2r.partition_S(sMean); - Tensor tSR_sInvStddev = thr_s2r.partition_S(sInvStddev); - Tensor tSR_sColAlpha = thr_s2r.partition_S(sColAlpha); - Tensor tSR_sColBias = thr_s2r.partition_S(sColBias); - Tensor tSR_cAlpha = thr_s2r.partition_S(args.cD); - - Tensor tSR_rAlpha = make_tensor_like(take<0,3>(tSR_sAlpha)); // need to check - Tensor tSR_rBias = make_tensor_like(take<0,3>(tSR_sBias)); - Tensor tSR_rMean = make_tensor_like(take<0,3>(tSR_sMean)); - Tensor tSR_rInvStddev = make_tensor_like(take<0,3>(tSR_sInvStddev)); - - Tensor tCsAlpha = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - sColAlpha, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCsBias = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - sColBias, args.epi_tile, args.tiled_copy, args.thread_idx); - - Tensor tCrAlpha = make_tensor_like(take<0,5>(tCsAlpha)); // (CPY,CPY_M,CPY_N) - Tensor tCrBias = make_tensor_like(take<0,5>(tCsBias)); // (CPY,CPY_M,CPY_N) - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ConsumerStoreCallbacks( - cute::move(tSR_rAlpha), cute::move(tSR_rBias), - cute::move(tSR_rMean), cute::move(tSR_rInvStddev), - cute::move(tSR_sAlpha), cute::move(tSR_sBias), - cute::move(tSR_sMean), cute::move(tSR_sInvStddev), - cute::move(tSR_cAlpha), - cute::move(tSR_sColAlpha), cute::move(tSR_sColBias), - cute::move(tCrAlpha), cute::move(tCrBias), - cute::move(tCsAlpha), cute::move(tCsBias), - ThreadCount{}, - params_ptr); - } -}; - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp deleted file mode 100644 index d026b15ccacef0bb199b7a98172c722f9402d075..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp +++ /dev/null @@ -1,666 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree store operations for the sm100 TMA warp-specialized (ws) epilogue -*/ - - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/detail/sm100_blockscaled_layout.hpp" -#include "cute/tensor.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" -#include "cutlass/detail/helper_macros.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -namespace detail { - template - CUTLASS_DEVICE auto - compute_quantized_with_row_scalefactor( - Array& frg_compute, - Array& frg_sf, - ElementCompute norm_constant) - { - cutlass::multiplies mul; - cutlass::multiplies> mul_array; - - Array frg_output; - auto output_frgs = reinterpret_cast *>(frg_output.data()); - auto compute_frgs = reinterpret_cast *>(frg_compute.data()); - - Array qpvscale_rcps = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (cute::is_same_v) { - // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate>{}(frg_sf); - return cutlass::NumericArrayConverter{}(e8m0_qpvscale_rcp); - } - else { - // UE4M3: Do the rcp in fp32 data type. - auto qpvscale_ups = cutlass::NumericArrayConverter{}(frg_sf); - return cutlass::reciprocal_approximate_ftz{}(qpvscale_ups); - } - }(); - - // norm_constant and qpvscale_rcps are all positive numbers. - auto acc_scales = cutlass::multiplies>{}(norm_constant, qpvscale_rcps); - - CUTLASS_PRAGMA_UNROLL - for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { - // Map INF to fp32::max - auto acc_scale = minimum_with_nan_propagation{}(acc_scales[sf_v], cutlass::platform::numeric_limits::max()); - // Convert to output type - output_frgs[sf_v] = cutlass::NumericArrayConverter{}(mul_array(compute_frgs[sf_v], acc_scale)); - } - return frg_output; - } -} -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// BlockScaleFactor Generation Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int SFVecSize, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -struct Sm100BlockScaleFactorRowStore { - static_assert(size<1>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); - static_assert(size<1>(EpilogueTile{}) / SFVecSize == 1 or - size<1>(EpilogueTile{}) / SFVecSize == 2 or - size<1>(EpilogueTile{}) / SFVecSize == 4 or - size<1>(EpilogueTile{}) / SFVecSize == 8, - "Possible store in interleaved 4B aligned format"); - using NormalConstStrideMNL = Stride<_0,_0,int64_t>; - struct SharedStorage { }; - - struct Arguments { - ElementBlockScaleFactor* ptr_scale_factor = nullptr; - ElementCompute const* norm_constant_ptr = nullptr; - NormalConstStrideMNL norm_constant_stride = {}; - }; - - using Params = Arguments; - - using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - bool implementable = (N % SFVecSize == 0); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm100BlockScaleFactorRowStore] N-dim should be divisible by SFVecSize.\n"); - } - return implementable; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm100BlockScaleFactorRowStore() { } - - CUTLASS_HOST_DEVICE - Sm100BlockScaleFactorRowStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { } - - Params const* params_ptr = nullptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template < - class RTensor, - class GTensor, - class CoordGTensor, - class ThrResidue, - class EpiTileCoordMN, - class ElementType - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rSFD_, // (CPY,CPY_M,CPY_N) - GTensor&& tC_gSFD_, // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - CoordGTensor tC_cSFD_, // (m,n) - ThrResidue residue_tC_cSFD_, // (m,n) - Params const* params_ptr_, - EpiTileCoordMN epi_tile_coord_mn_, // (epi_tile_coord_m, epi_tile_coord_n) - ElementType norm_constant_, - ElementType norm_constant_scaled_down_) - : tC_rSFD(cute::forward(tC_rSFD_)) - , tC_gSFD(cute::forward(tC_gSFD_)) - , tC_cSFD(tC_cSFD_) - , residue_tC_cSFD(residue_tC_cSFD_) - , params_ptr(params_ptr_) - , norm_constant(norm_constant_) - , norm_constant_scaled_down(norm_constant_scaled_down_) - , epi_tile_coord_mn(epi_tile_coord_mn_){} - - static_assert(is_same_v); - RTensor tC_rSFD; - GTensor tC_gSFD; - CoordGTensor tC_cSFD; - ThrResidue residue_tC_cSFD; - Params const* params_ptr; - ElementCompute norm_constant; - ElementCompute norm_constant_scaled_down; - EpiTileCoordMN epi_tile_coord_mn; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, - int epi_v, - int epi_m, - int epi_n, - Array const& frg_input) - { - static_assert(FragmentSize % SFVecSize == 0, "Scale factor vector size should divide FragmentSize"); - constexpr int NumVecs = FragmentSize / SFVecSize; - Array frg_compute; - - auto input_frgs = reinterpret_cast const*>(frg_input.data()); - auto compute_frgs = reinterpret_cast *>(frg_compute.data()); - - Tensor tC_rSFD_frg = recast>(coalesce(filter(tC_rSFD))); // (EPI_V) - - cutlass::multiplies mul; - cutlass::maximum_absolute_value_reduction, true> amax_reduction; - - cutlass::Array vec_maxs; - cutlass::Array pvscales; - // SF generation - CUTLASS_PRAGMA_UNROLL - for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { - compute_frgs[sf_v] = NumericArrayConverter{}(input_frgs[sf_v]); - /// Step1: get max across a vector - vec_maxs[sf_v] = amax_reduction(ElementCompute(0), compute_frgs[sf_v]); - } - - /// Step2: Compute Scale - pvscales = cutlass::multiplies>{}(vec_maxs, norm_constant_scaled_down); - - tC_rSFD_frg(_0{}) = cutlass::NumericArrayConverter{}(pvscales); - - Tensor tCgSFD_flt = filter_zeros(tC_gSFD(_,_,_,_0{},_0{},get<0>(epi_tile_coord_mn) + epi_m, get<1>(epi_tile_coord_mn) + epi_n)); - Tensor tCrSFD_flt = filter_zeros(tC_rSFD); - constexpr auto MCL = decltype(max_common_layout(tCgSFD_flt, tCrSFD_flt)){}; - constexpr int V = cute::min(4, size(MCL)); - using VecType = uint_bit_t>; - Tensor tCgSFD_vec = recast(coalesce(tCgSFD_flt)); - Tensor tCrSFD_vec = recast(coalesce(tCrSFD_flt)); - Tensor tCcSFD_pred = tC_cSFD(_,_,_, epi_m, epi_n); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrSFD_vec); i++){ - if (elem_less(tCcSFD_pred(i * SFVecSize * V), residue_tC_cSFD)) { - tCgSFD_vec(i) = tCrSFD_vec(i); - } - } - /// Step3: Compute quantized output values - return detail::compute_quantized_with_row_scalefactor(frg_compute, tC_rSFD_frg(_0{}), norm_constant); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [tile_coord_m, tile_coord_n, tile_coord_k, tile_coord_l] = args.tile_coord_mnkl; - using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; - UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; - // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group - if constexpr (!cute::is_same_v) { - ptr_scale_factor = params_ptr->ptr_scale_factor[tile_coord_l]; - tile_coord_l = 0; - } - else { - ptr_scale_factor = params_ptr->ptr_scale_factor; - } - - auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); - static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); - Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_,_,tile_coord_l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) - Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) - - auto epi_tile_coord_mn = make_coord(tile_coord_m * size<0>(epi_tile_mn), tile_coord_n * size<1>(epi_tile_mn)); - - // Fetch and compute these during initialization - Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); - ElementCompute norm_constant = mNormConst(_0{},_0{},tile_coord_l); - ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); - ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); - ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); -#if 0 - if(threadIdx.x == 128 && blockIdx.x == 0 && blockIdx.y == 0){ - print("epi_tile ");print(args.epi_tile); print("\n"); - print("mSFD ");print(mSFD); print("\n"); - print("gSFD ");print(gSFD); print("\n"); - print("tCgSFD ");print(tCgSFD); print("\n"); - print("tCrSFD ");print(tCrSFD); print("\n"); - print("filter(tCrSFD) ");print(filter(tCrSFD)); print("\n"); - print("filter(tCgSFD) ");print(filter(tCgSFD)); print("\n"); - } -#endif - - return ConsumerStoreCallbacks( - cute::move(tCrSFD), - cute::move(tCgSFD), - args.tCcD, - args.residue_tCcD, - params_ptr, - epi_tile_coord_mn, - norm_constant, - norm_constant_scaled_down); - - } -}; - -template < - int SFVecSize, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -struct Sm100BlockScaleFactorColStore { - - static_assert(size<0>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); - static_assert(size<0>(EpilogueTile{}) / SFVecSize == 1 or - size<0>(EpilogueTile{}) / SFVecSize == 2 or - size<0>(EpilogueTile{}) / SFVecSize == 4 or - size<0>(EpilogueTile{}) / SFVecSize == 8, - "Possible store in interleaved 4B aligned format"); - using NormalConstStrideMNL = Stride<_0,_0,int64_t>; - static constexpr int NumSyncWarps = SFVecSize == 64 ? 4 : 0; - static constexpr int NumSyncThreads = NumSyncWarps * NumThreadsPerWarp; - struct SharedStorage { - array_aligned smem_aux; - }; - - struct Arguments { - ElementBlockScaleFactor* ptr_scale_factor = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - ElementCompute const* norm_constant_ptr = nullptr; - NormalConstStrideMNL norm_constant_stride = {}; - }; - - using Params = Arguments; - - // BlockScaleFactor generation is per batch or group - // For Ptr-Array GEMM and Grouped GEMM, ElementBlockScaleFactor is ElementType* - using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - bool implementable = (M % SFVecSize == 0); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm100BlockScaleFactorColStore] M-dim should be divisible by SFVecSize.\n"); - } - return implementable; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm100BlockScaleFactorColStore() { } - - CUTLASS_HOST_DEVICE - Sm100BlockScaleFactorColStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) - , smem_aux(const_cast(shared_storage.smem_aux.data())) { } - - Params const* params_ptr = nullptr; - ElementCompute *smem_aux = nullptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template < - class RTensor, - class GTensor, - class STensor, - class CoordGTensor, - class ThrResidue, - class EpiTileCoordMN, - class ElementType - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - // Normally, we should use tile_shape_mnk to tile the gtensor. - // However, the SF gtensor could not be divisible by non-pow2 cta tile, so we use epi tile (pow2) to do tiling. - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rSFD_, // (CPY,CPY_M,CPY_N) - GTensor&& tC_gSFD_, // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - STensor&& sAmaxs_, // (NumSyncWarps) - CoordGTensor tC_cSFD_, // (m,n) - ThrResidue residue_tC_cSFD_, // (m,n) - Params const* params_ptr_, - EpiTileCoordMN epi_tile_coord_mn_, // (epi_tile_coord_m, epi_tile_coord_n) - ElementType norm_constant_, - ElementType norm_constant_scaled_down_) - : tC_rSFD(cute::forward(tC_rSFD_)) - , tC_gSFD(cute::forward(tC_gSFD_)) - , sAmaxs(cute::forward(sAmaxs_)) - , tC_cSFD(tC_cSFD_) - , residue_tC_cSFD(residue_tC_cSFD_) - , params_ptr(params_ptr_) - , norm_constant(norm_constant_) - , norm_constant_scaled_down(norm_constant_scaled_down_) - , epi_tile_coord_mn(epi_tile_coord_mn_) {} - - static_assert(is_same_v); - RTensor tC_rSFD; - GTensor tC_gSFD; - STensor sAmaxs; - CoordGTensor tC_cSFD; - ThrResidue residue_tC_cSFD; - Params const* params_ptr; - ElementCompute norm_constant; - ElementCompute norm_constant_scaled_down; - EpiTileCoordMN epi_tile_coord_mn; - - CUTLASS_DEVICE - ElementCompute find_amax(ElementCompute max) { - // Overall idea: after TMEM_LOAD.32DP32bit pattern, each thread in the warp can load adjacent elements of a column into its private RF. - // Here we are using shuffle instructons to the amax value of the adjacent column elements. - // For VS16, t0~t15 would generate an amax, and t16~t31 would generate another one. - // For VS32, t0~t31 should generate an amax. - // For VS64, t0~t63 should generate an amax. We would first do the reduciton within a warp, - // and then use smem to do inter-warp reduction. - if constexpr (SFVecSize == 32) { - return cutlass::redux_abs_max_nan_propagation_sync_warp{}(max); - } - else if constexpr (SFVecSize == 16) { - return cutlass::redux_abs_max_nan_propagation_sync_warp_t0t15_t16t31{}(max); - } - else if constexpr (SFVecSize == 64) { - // Get abs_max per warp - auto abs_max = cutlass::redux_abs_max_nan_propagation_sync_warp{}(max); - - // Switch the amax of adjacent warps - const bool leading_thread = (threadIdx.x % NumThreadsPerWarp) == 0; - const int warp_idx = threadIdx.x / NumThreadsPerWarp % 4; - auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(NumSyncThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - // Inter-warp reduction for VS=64 - // Only 4 * FP32 = 16 bytes smem is needed as we have 4 warps. - if (leading_thread) { - sAmaxs(warp_idx) = abs_max; - } - synchronize(); - // Switch data between two adjacent warps to do reduction - float tmp = sAmaxs(warp_idx^1); - synchronize(); - abs_max = cutlass::maximum_with_nan_propagation{}(abs_max,tmp); - return abs_max; - } - else { - static_assert(cutlass::detail::dependent_false, "Unsupported VecSize"); - } - } - - template - CUTLASS_DEVICE auto - compute_quantized_value(Array compute, Array sf) { - cutlass::multiplies> mul_array; - auto qpvscale_rcp = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (cute::is_same_v) { - // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcps = cutlass::reciprocal_approximate>{}(sf); - return cutlass::NumericArrayConverter{}(e8m0_qpvscale_rcps); - } - else { - // UE4M3: Do the rcp in fp32 data type. - auto qpvscale_up = cutlass::NumericArrayConverter{}(sf); - return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); - } - }(); - // norm_constant and qpvscale_rcps[sf_v] are all positive numbers. - auto acc_scale = mul_array(norm_constant, qpvscale_rcp); - // Map INF to fp32::max - acc_scale = minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); - return mul_array(compute, acc_scale); - } - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, - int epi_v, - int epi_m, - int epi_n, - Array const& frg_input) - { - constexpr int NumVecs = 1; // each thread only compute 1 col scalefactors - Array frg_compute; - Array frg_output; - Array frg_scale_float; - Array frg_amax; - Array frg_scale; - - Tensor tC_rSFD_frg = recast>(coalesce(filter(tC_rSFD))); // (EPI_V) - - cutlass::multiplies mul; - cutlass::multiplies> mul_array; - /// convert acc to Element Compute - auto compute_frgs = NumericArrayConverter{}(frg_input); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - /// Step1: get max across a vector - frg_amax[i] = find_amax(compute_frgs[i]); - } - - frg_scale_float = mul_array(frg_amax, norm_constant_scaled_down); - frg_scale = cutlass::NumericArrayConverter{}(frg_scale_float); - auto tC_cSFD_pred = tC_cSFD(_,_,_,epi_m,epi_n); - auto tC_gSFD_store = tC_gSFD(_,_,_,_,_,get<0>(epi_tile_coord_mn) + epi_m, get<1>(epi_tile_coord_mn) + epi_n); - for (int i=0; i < cute::ceil_div(FragmentSize, SFVecSize); i++) { - int idx = i * SFVecSize + threadIdx.x % SFVecSize; - if (idx < FragmentSize && elem_less(tC_cSFD_pred(idx), residue_tC_cSFD)) { - UnderlyingElementBlockScaleFactor tmp = frg_scale[idx]; - // Store the (EpilogueTile / SFVecSize) elements. - tC_gSFD_store(idx) = tmp; - } - } - - /// Step3: Compute quantized output values - if constexpr (cute::sizeof_bits_v == 4) { - return compute_quantized_value(compute_frgs, frg_scale); // ElementCompute - } - else { - // 6bits or 8bits output. - compute_frgs = compute_quantized_value(compute_frgs, frg_scale); - frg_output = cutlass::NumericArrayConverter{}(compute_frgs); - return frg_output; // ElementOutput - } - - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [tile_coord_m, tile_coord_n, tile_coord_k, tile_coord_l] = args.tile_coord_mnkl; - using Sm1xxBlockScaledOutputConfig = cutlass::detail::Sm1xxBlockScaledOutputConfig; - UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; - // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group - if constexpr (!cute::is_same_v) { - ptr_scale_factor = params_ptr->ptr_scale_factor[tile_coord_l]; - tile_coord_l = 0; - } - else { - ptr_scale_factor = params_ptr->ptr_scale_factor; - } - - auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); - //Tensor gSFD = local_tile(mSFD, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); - // Normally, we should use tile_shape_mnk to tile the mSFD tensor. However, we could not do it for non-pow2 cta tile with vectorsize = 32. - // For scale factor, 128x4 elements are stored in a basic block, and the layout of mSFD is ((_32,_4,int),(_32,_4,int),int):((_16,_4,int),(_0,_1, int),int) - // If we tiled it using tile_shape_mnk(128, 192), the N mode would encounter shape_div failure because (32, 4) could not be divisible by 192. - // Therefore, switching to using pow2 epilogue tile. - static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); - Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_,_,tile_coord_l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) - Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) - - auto epi_tile_coord_mn = make_coord(tile_coord_m * size<0>(epi_tile_mn), tile_coord_n * size<1>(epi_tile_mn)); - - // Fetch and compute these during initialization - Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); - ElementCompute norm_constant = mNormConst(_0{},_0{},tile_coord_l); - ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); - ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); - ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); - - Tensor sAmaxs = make_tensor(make_smem_ptr(smem_aux), make_layout(_4{})); -#if 0 - if(threadIdx.x == 128 && blockIdx.x == 0 && blockIdx.y == 0){ - print("mSFD ");print(mSFD); print("\n"); - print("gSFD ");print(gSFD); print("\n"); - print("tCgSFD ");print(tCgSFD); print("\n"); - print("tCrSFD ");print(tCrSFD); print("\n"); - print("args.tCcD ");print(args.tCcD); print("\n"); - print("args.residue_tCcD ");print(args.residue_tCcD); print("\n"); - print("filter(tCrSFD) ");print(filter(tCrSFD)); print("\n"); - print("filter(tCgSFD) ");print(filter(tCgSFD)); print("\n"); - } -#endif - - return ConsumerStoreCallbacks( - cute::move(tCrSFD), - cute::move(tCgSFD), - cute::move(sAmaxs), - args.tCcD, - args.residue_tCcD, - params_ptr, - epi_tile_coord_mn, - norm_constant, - norm_constant_scaled_down); - } -}; - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp deleted file mode 100644 index b769b1f0fbe2aa78f0ee97da442fb61c1aa49cc8..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp +++ /dev/null @@ -1,1593 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - -/*! \file - \brief Fusion callbacks specializations for the SM120 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/fusion/callbacks.hpp" -#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Sm120 Tma warp specialized callbacks just alias to their sm90 counterpart -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... -> : FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... - > { - using FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>::FusionCallbacks; -}; - -// D = alpha * acc + beta * C -// With BlockScaleFactor Generation. -// 1. Find max of 32 F32 elements -// 2. Convert the max to UE8 (or UE4M3) and store the result. -// 3. Convert the UE8 (or UE4M3) back to F32 scale. -// 4. Reciprocal of F32 scale with MUFU. -// 5. Multiply each F32 element with the above reciprocal, then convert to ElementD -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinearCombRowBlockScaleFactor = - Sm90EVT, // gen scalefactor - Sm90LinearCombination // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinearCombRowBlockScaleFactor::type,ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm120LinearCombRowBlockScaleFactor::type,ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - - using Sm100Fusion = FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile - >; - using Operation = typename Sm100Fusion::Operation; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { - { - // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = alpha * acc + beta * C + per-row bias -// with row blockScaled generation -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerRowBiasRowBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, // gen scalefactor - Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerRowBiasRowBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm120LinCombPerRowBiasRowBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = activation(alpha * acc + beta * C + per-row bias) -// with row blockScaled generation -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerRowBiasEltActRowBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, // gen scalefactor - Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerRowBiasEltActRowBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm120LinCombPerRowBiasEltActRowBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = alpha * acc + beta * C + per_col bias -// with row blockScaled generation -template< - int StagesC, - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerColBiasRowBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, // gen scalefactor - Sm90LinCombPerColBias< - StagesC, CtaTileShapeMNK, EpilogueTile, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerColBiasRowBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm120LinCombPerColBiasRowBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = activation(alpha * acc + beta * C + per_col bias) -// with row blockScaled generation -template< - int StagesC, - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerColBiasEltActRowBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, // gen scalefactor - Sm90LinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerColBiasEltActRowBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm120LinCombPerColBiasEltActRowBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C -// with per column blockScaled generation -// 1. Find max of 32 F32 elements -// 2. Convert the max to UE8 (or UE4M3) and store the result. -// 3. Convert the UE8 (or UE4M3) back to F32 scale. -// 4. Reciprocal of F32 scale with MUFU. -// 5. Multiply each F32 element with the above reciprocal, then convert to ElementD -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinearCombColBlockScaleFactor = Sm90EVT< - Sm120BlockScaleFactorColStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle>, - Sm90LinearCombination< - ElementCompute, ElementCompute, ElementSource, ElementScalar, RoundStyle> - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized< - StagesC, StagesD, FragmentSize, ReuseSmemC, DelayTmaStore>, - fusion::LinCombBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute,ElementBlockScaleFactor, - cutlass::layout::ColumnMajor, ElementSource, ElementScalar, RoundStyle>, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinearCombColBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle - > { - - using Impl = Sm120LinearCombColBlockScaleFactor::type,ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - - using Sm100Fusion = FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile - >; - using Operation = typename Sm100Fusion::Operation; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { - { - // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = alpha * acc + beta * C + per-Col bias -// with per column blockScaled generation -template< - int StagesC, - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerColBiasColBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorColStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerColBias< - StagesC, CtaTileShapeMNK, EpilogueTile, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerColBiasColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm120LinCombPerColBiasColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = activation(alpha * acc + beta * C + per_col bias) -// with per column blockScaled generation -template< - int StagesC, - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerColBiasEltActColBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorColStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerColBiasEltActColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm120LinCombPerColBiasEltActColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = activation(alpha * acc + beta * C + per-row bias) -// with per column blockScaled generation -template< - int StagesC, - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerRowBiasEltActColBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorColStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerRowBiasEltActColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - - using Impl = - Sm120LinCombPerRowBiasEltActColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - - -// D = alpha * acc + beta * C + per-row bias -// with per column blockScaled generation -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerRowBiasColBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorColStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, // gen scalefactor - Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerRowBiasColBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm120LinCombPerRowBiasColBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// Sm120 Ptr array tma warp specialized callbacks just alias to their sm90 counterpart -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups, - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm120PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... -> : FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... - > { - using FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>::FusionCallbacks; -}; - -// For Ptr-Array and Grouped GEMM -// D = alpha * acc + beta * C, where alpha and beta can be vectors for each batch/group -// With Row BlockScaleFactor Generation, separate tensors per batch/group. -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinearCombRowBlockScaleFactorPtrArray = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor *, RoundStyle - >, // gen scalefactor - Sm90LinearCombinationPtrArray< ElementCompute, ElementCompute, - ElementSource, ElementScalar, RoundStyle - > // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120PtrArrayTmaWarpSpecialized, - fusion::LinCombBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementSource, ElementScalar, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinearCombRowBlockScaleFactorPtrArray< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle - > { - - using Impl = - Sm120LinearCombRowBlockScaleFactorPtrArray< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle - >; - - using Operation = - fusion::LinCombBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementSource, ElementScalar, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; - - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - - -// For Ptr-Array and Grouped GEMM -// D = activation(alpha * acc + beta * C), where alpha and beta can be vectors for each batch/group -// With Row BlockScaleFactor Generation, separate tensors per batch/group. -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombEltActRowBlockScaleFactorPtrArray = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor *, RoundStyle - >, // gen scalefactor - Sm90LinCombEltActPtrArray // activation(beta * C + (alpha * acc)) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120PtrArrayTmaWarpSpecialized, - fusion::LinCombEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementSource, ElementScalar, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombEltActRowBlockScaleFactorPtrArray< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle - > { - - using Impl = - Sm120LinCombEltActRowBlockScaleFactorPtrArray< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle - >; - - using Operation = - fusion::LinCombEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementSource, ElementScalar, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; - - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp deleted file mode 100644 index e72e971bd8d99f87a2528af3c1dbd27366298ef5..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp +++ /dev/null @@ -1,899 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - -/*! \file - \brief Visitor tree store operations for the SM120 TMA warp-specialized (ws) epilogue -*/ - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/detail/sm100_blockscaled_layout.hpp" -#include "cute/tensor.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// BlockScaleFactor Generation Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int SFVecSize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -struct Sm120BlockScaleFactorRowStore { - - static_assert(size<1>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); - static_assert(size<1>(EpilogueTile{}) / SFVecSize == 1 or - size<1>(EpilogueTile{}) / SFVecSize == 2 or - size<1>(EpilogueTile{}) / SFVecSize == 4 or - size<1>(EpilogueTile{}) / SFVecSize == 8, - "Possible store in interleaved 4B aligned format"); - - static constexpr int NumWarpgroups = 2; - static constexpr int NumSyncWarps = NumWarpsPerWarpGroup * NumWarpgroups; - static constexpr int NumQuadsPerWarp = 8; - static constexpr int NumSyncQuads = NumSyncWarps * NumQuadsPerWarp; - struct SharedStorage { - array_aligned smem_aux; - }; - using NormalConstStrideMNL = Stride<_0,_0,int64_t>; - struct Arguments { - ElementBlockScaleFactor* ptr_scale_factor = {}; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - ElementCompute const* norm_constant_ptr = {}; - NormalConstStrideMNL norm_constant_stride = {}; - }; - - using Params = Arguments; - - using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - bool implementable = (N % SFVecSize == 0); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm120BlockScaleFactorRowStore] N-dim should be divisible by SFVecSize.\n"); - } - return implementable; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm120BlockScaleFactorRowStore() { } - - CUTLASS_HOST_DEVICE - Sm120BlockScaleFactorRowStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) - , smem_aux(const_cast(shared_storage.smem_aux.data())) { } - - Params const* params_ptr = nullptr; - ElementCompute *smem_aux = nullptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template < - class RTensor, - class GTensor, - class STensor, - class CoordGTensor, - class ThrResidue, - class TileCoordMN, - class ElementType, - class TiledCopy_ - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rSFD_, - GTensor&& tC_gSFD_, - STensor&& sAmaxs_, - CoordGTensor tC_cSFD_, - ThrResidue residue_tC_cSFD_, - Params const* params_ptr_, - TileCoordMN tile_coord_mn_, - ElementType norm_constant_, - ElementType norm_constant_scaled_down_, - int thread_idx_, - TiledCopy_ const&) - : tC_rSFD(cute::forward(tC_rSFD_)) - , tC_gSFD(cute::forward(tC_gSFD_)) - , sAmaxs(cute::forward(sAmaxs_)) - , tC_cSFD(tC_cSFD_) - , residue_tC_cSFD(residue_tC_cSFD_) - , params_ptr(params_ptr_) - , norm_constant(norm_constant_) - , norm_constant_scaled_down(norm_constant_scaled_down_) - , tile_coord_mn(tile_coord_mn_) - , thread_idx(thread_idx_) {} - - static_assert(is_same_v); - RTensor tC_rSFD; - GTensor tC_gSFD; - STensor sAmaxs; - CoordGTensor tC_cSFD; - ThrResidue residue_tC_cSFD; - Params const* params_ptr; - ElementCompute norm_constant; - ElementCompute norm_constant_scaled_down; - TileCoordMN tile_coord_mn; - int thread_idx; - static constexpr int NumCollaboratingThreads = decltype(size(TiledCopy_{}))::value; - static_assert(NumCollaboratingThreads % NumThreadsPerWarpGroup == 0); - static constexpr int NumCollaboratingWarpGroups = NumCollaboratingThreads / NumThreadsPerWarpGroup; - static_assert(NumCollaboratingWarpGroups == 1 || NumCollaboratingWarpGroups == 2, - "SM120 epilogue currently only supports one or two warp groups collaborating."); - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, - int epi_v, - int epi_m, - int epi_n, - Array const& frg_input) { - return frg_input; - } - - template - CUTLASS_DEVICE void - reduce(SmemTensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - /* - Accumulator fragments are distributed across quads in different warps. - For SFVector = 16, we have: - - 8 elements 8 elements 8 elements 8 elements - <----------------><-----------------><-----------------><-----------------> - Warp 0 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 Warp 4 Quad 0 - Warp 0 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 Warp 4 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 Warp 4 Quad 7 - Warp 0 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 Warp 4 Quad 0 - Warp 0 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 Warp 4 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 Warp 4 Quad 7 - - - - - - In this case, row-wise scale factors are cooperatively reduced across 4 - threads from 1 quad in 1 warp. Each quad computes its own, local absolute - maximum without communicating with other warps through shared memory. - - For SFVector = 32, we have: - 8 elements 8 elements 8 elements 8 elements - <----------------><-----------------><-----------------><-----------------> - Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 - Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 - Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 - Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 - - - - - - For SFVector = 64, we have: - 8 elements 8 elements 8 elements 8 elements - <----------------><-----------------><-----------------><-----------------> - Warp 0 Quad 0 Warp 2 Quad 0 Warp 4 Quad 0 Warp 6 Quad 0 - Warp 0 Quad 1 Warp 2 Quad 1 Warp 4 Quad 1 Warp 6 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 2 Quad 7 Warp 4 Quad 7 Warp 6 Quad 7 - Warp 0 Quad 0 Warp 2 Quad 0 Warp 4 Quad 0 Warp 6 Quad 0 - Warp 0 Quad 1 Warp 2 Quad 1 Warp 4 Quad 1 Warp 6 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 2 Quad 7 Warp 4 Quad 7 Warp 6 Quad 7 - - - - Thus, rowwise scale factors are cooperatively reduced across 8 threads - from two quads in two warps. Each quad first computes its own, local - absolute maximum and then shares this with the corresponding quad in the - other warp. In this case, a reduction through shared memory is needed. - - For a non-cooperative epilogue (in which each warpgroup computes a - separate tile), the pattern is the same as that above, except that warps 0 - and 2 are in the same row, and 1 and 3 are in the same row, and warps 4-7 - are not included. - */ - - // Accumulator fragments consist of two elements from two different rows of a 16x8 MMA output - static constexpr int ColsPerThreadAccFrag = 2; - static constexpr int RowsPerThreadAccFrag = 2; - static_assert(FragmentSize == - (ColsPerThreadAccFrag * RowsPerThreadAccFrag)); - - static constexpr int NumThreadsPerQuad = 4; - static_assert(SFVecSize == 16 || SFVecSize == 32 || SFVecSize == 64, "SF vector size must be either 16, 32 or 64."); - // A quad from two or four warps participate in computing each scale factor. - constexpr int WarpsPerSF = SFVecSize / 16; - static_assert(WarpsPerSF == 1 || WarpsPerSF == 2 || WarpsPerSF == 4, "Only one, two or four warps are allowed in reduction."); - - constexpr bool IsInterWarpReductionNeeded = (WarpsPerSF != 1); - - // Number of fragments for each thread that are needed for computing a scale factor - static constexpr int AccFragsPerSF = SFVecSize / (ColsPerThreadAccFrag * NumThreadsPerQuad * WarpsPerSF); - static_assert(size<2>(visit_results) % AccFragsPerSF == 0, - "Fragments along N mode must be a multiple of the number of accumulator fragments needed per SF"); - - auto warp_idx = thread_idx / NumThreadsPerWarp; - auto warpgroup_idx = thread_idx / NumThreadsPerWarpGroup; - auto quad_idx_in_warp = (thread_idx % NumThreadsPerWarp) / NumThreadsPerQuad; - auto thread_idx_in_quad = thread_idx % NumThreadsPerQuad; - - cutlass::maximum_absolute_value_reduction amax_op; - cutlass::multiplies mul; - - Tensor tC_rSFD_flt = filter_zeros(tC_rSFD); - - auto synchronize = [&] () { - cutlass::arch::NamedBarrier::sync(NumCollaboratingThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - }; - - CUTLASS_PRAGMA_UNROLL - for (int sf_id = 0; sf_id < size(tC_rSFD_flt); ++sf_id) { - - auto coord = idx2crd(sf_id, tC_rSFD_flt.shape()); - auto row_in_acc = get<0,1,1>(coord); - auto row = crd2idx(get<1>(coord), get<1>(tC_rSFD_flt.shape())); - auto sf = crd2idx(get<2>(coord), get<2>(tC_rSFD_flt.shape())); - - // - // Compute amax for this scale factor - // - ElementCompute amax{0}; - - // Compute amax among vals owned by this thread for this vector - auto acc_frag_row = row_in_acc * RowsPerThreadAccFrag; - auto acc_frag_start_for_sf = sf * AccFragsPerSF; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < AccFragsPerSF; ++i) { - auto acc_frg = visit_results(0, row, acc_frag_start_for_sf + i); - amax = amax_op(amax, acc_frg[acc_frag_row]); - amax = amax_op(amax, acc_frg[acc_frag_row + 1]); - } - - // At this point, each thread has computed the amax of the values that it owns for this SF vector. - // We now need to compute the amax across threads. Because the TiledMMA uses an MmaThrLayout of <4,1,1>, - // we know that all fragments in this row will belong to threads in this warp. Furthermore, because - // SM120 narrow-precision MMAs have 16x8 output size with a quad owning two rows, we know that a quad - // will own all of the elements to be reduced via amax. Therefore, we can use warp shuffle intrinsics - // among threads in one quad to compute the amax. - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < 3; ++i) { - auto amax_other = __shfl_xor_sync(0xffffffff, amax, i); - amax = amax_op(amax, amax_other); - } - - if constexpr (IsInterWarpReductionNeeded) { - // At this point, all threads in the quad have the amax for the elements of the accumulator owned by its quad - // that should be used in computing the amax for this SF. Threads 0 in each quad of warps 0 and 2 - // (similarly, 1 and 3) now exchange amaxes to compute the final amax. - if (thread_idx_in_quad == 0) { - sAmaxs(quad_idx_in_warp, warp_idx) = amax; - } - synchronize(); - - // Get the amax broadcasted by the warp with which we share. - // Work on 4 warps per SFD generation - if constexpr (WarpsPerSF == 4) { - if constexpr (NumCollaboratingWarpGroups == 2) { - // This implementation assumes warp layout 2 x 4. - // For cooperative kernels (NumCollaboratingWarpGroups=2), - // warp 0 shares with 2 / 4 / 6, warp 1 shares with 3 / 5/ 7. - auto amax_other2 = sAmaxs(quad_idx_in_warp, warp_idx ^ 2); - auto amax_other4 = sAmaxs(quad_idx_in_warp, warp_idx ^ 4); - auto amax_other6 = sAmaxs(quad_idx_in_warp, warp_idx ^ 6); - synchronize(); - amax = amax_op(amax, amax_other2); - amax = amax_op(amax, amax_other4); - amax = amax_op(amax, amax_other6); - } - else { - static_assert(cutlass::detail::dependent_false, "Unsupported warp layout."); - } - } - // Work on 2 warps per SFD generation - else if constexpr(WarpsPerSF == 2) { - // For cooperative kernels (NumCollaboratingWarpGroups=2), 0 shares - // with 4, 1 shares with 5, etc. For non-cooperative kernels - // (NumCollaboratingWarpGroups=1), 0 shares with 2, 1 shares with 3. - auto amax_other = sAmaxs( - quad_idx_in_warp, warp_idx ^ (1 << NumCollaboratingWarpGroups)); - synchronize(); - amax = amax_op(amax, amax_other); - } - } - - ElementCompute pvscale = mul(amax, norm_constant_scaled_down); - UnderlyingElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); - tC_rSFD_flt(coord) = qpvscale; - - // - // Apply the scale factor to the output - // - ElementCompute qpvscale_rcp = [&]() { - if constexpr (cute::is_same_v) { - // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); - return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); - } - else { - // UE4M3: Do the rcp in fp32 data type. - auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); - return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); - } - }(); - - ElementCompute acc_scale = mul(norm_constant, qpvscale_rcp); - acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); - - // Compute quantized output values - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < AccFragsPerSF; ++i) { - auto acc_frag = visit_results(0, row, acc_frag_start_for_sf + i); - visit_results(0, row, acc_frag_start_for_sf + i)[acc_frag_row ] = mul(acc_frag[acc_frag_row], acc_scale); - visit_results(0, row, acc_frag_start_for_sf + i)[acc_frag_row + 1] = mul(acc_frag[acc_frag_row + 1], acc_scale); - } - } // sf - - // Since scale factors are computed cooperatively across two quads from two warps, we only need one thread from the - // set of 8 cooperating threads to write out the data. We do this with thread 0 in each quad of the first warp that collaborates. - bool write_sf = (thread_idx_in_quad == 0); - if constexpr (NumCollaboratingWarpGroups == 2) { - // For cooperative kernels (NumCollaboratingWarpGroups=2), 0 shares with 4, 1 shares with 5, etc. - // Thus, only the warps in the first warpgroup need to write out scale factors. - if constexpr (IsInterWarpReductionNeeded) { - write_sf &= warp_idx < NumWarpsPerWarpGroup; - } - } - else { - if constexpr (IsInterWarpReductionNeeded) { - // When non-cooperative kernels apply inter warp reduce, they are with - // SF output rule as below : - // 1. warp 0 shares with 2 and 1 shares with 3 within each warpgroup. - // 2. warps 0 and 1 of the first warpgroup and 4 and 5 of the second - // warpgroup need to write output sf. - write_sf &= ((warp_idx < 2) || (warpgroup_idx == 1 && warp_idx < 6)); - } - } - - if (write_sf && elem_less(tC_cSFD(_0{}, _0{}, _0{}, epi_m, epi_n), residue_tC_cSFD)) { - copy_aligned(tC_rSFD, tC_gSFD(_, _, _, _0{}, _0{}, get<0>(tile_coord_mn) + epi_m, get<1>(tile_coord_mn) + epi_n)); - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - using Sm1xxBlockScaledOutputConfig = cutlass::detail::Sm1xxBlockScaledOutputConfig; - UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; - // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group - if constexpr (!cute::is_same_v) { - ptr_scale_factor = params_ptr->ptr_scale_factor[l]; - l = 0; - } - else { - ptr_scale_factor = params_ptr->ptr_scale_factor; - } - - auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); - - static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); - Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_, _,l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) - Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) - - auto tile_coord_mn = make_coord(m * size<0>(epi_tile_mn), n * size<1>(epi_tile_mn)); - - // Fetch and compute these during initialization - Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); - ElementCompute norm_constant = mNormConst(_0{},_0{},l); - ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); - ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); - ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); - - Tensor sAmaxs = make_tensor( - make_smem_ptr(smem_aux), - make_layout(make_shape(Int{}, Int{})) - ); - - return ConsumerStoreCallbacks( - cute::move(tCrSFD), - cute::move(tCgSFD), - cute::move(sAmaxs), - args.tCcD, - args.residue_tCcD, - params_ptr, - tile_coord_mn, - norm_constant, - norm_constant_scaled_down, - args.thread_idx, - args.tiled_copy); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int SFVecSize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -struct Sm120BlockScaleFactorColStore { - - static_assert(size<0>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); - static_assert(size<0>(EpilogueTile{}) / SFVecSize == 1 or - size<0>(EpilogueTile{}) / SFVecSize == 2 or - size<0>(EpilogueTile{}) / SFVecSize == 4, - "Possible store in interleaved 4B aligned format"); - - static constexpr int NumWarpgroups = 2; - static constexpr int NumSyncWarps = NumWarpsPerWarpGroup * NumWarpgroups; - static constexpr int NumThreadsPerQuad = 4; - static constexpr int NumSyncElementsCrossWarp = NumSyncWarps * NumThreadsPerQuad; - struct SharedStorage { - array_aligned smem_aux; - }; - - using NormalConstStrideMNL = Stride<_0,_0,int64_t>; - - struct Arguments { - ElementBlockScaleFactor* ptr_scale_factor = {}; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - ElementCompute const* norm_constant_ptr = {}; - NormalConstStrideMNL norm_constant_stride = {}; - }; - using Params = Arguments; - - using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - bool implementable = (M % SFVecSize == 0); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm120BlockScaleFactorColStore] N-dim should be divisible by SFVecSize.\n"); - } - return implementable; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm120BlockScaleFactorColStore() { } - - CUTLASS_HOST_DEVICE - Sm120BlockScaleFactorColStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) - , smem_aux(const_cast(shared_storage.smem_aux.data())) { } - - Params const* params_ptr = nullptr; - ElementCompute *smem_aux = nullptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template < - class RTensor, - class GTensor, - class STensor, - class CoordGTensor, - class ThrResidue, - class TileCoordMN, - class ElementType, - class TiledCopy_ - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rSFD_, - GTensor&& tC_gSFD_, - STensor&& sAmaxs_, - CoordGTensor tC_cSFD_, - ThrResidue residue_tC_cSFD_, - Params const* params_ptr_, - TileCoordMN tile_coord_mn_, - ElementType norm_constant_, - ElementType norm_constant_scaled_down_, - int thread_idx_, - TiledCopy_ const&) - : tC_rSFD(cute::forward(tC_rSFD_)) - , tC_gSFD(cute::forward(tC_gSFD_)) - , sAmaxs(cute::forward(sAmaxs_)) - , tC_cSFD(tC_cSFD_) - , residue_tC_cSFD(residue_tC_cSFD_) - , params_ptr(params_ptr_) - , norm_constant(norm_constant_) - , norm_constant_scaled_down(norm_constant_scaled_down_) - , tile_coord_mn(tile_coord_mn_) - , thread_idx(thread_idx_) {} - - static_assert(is_same_v); - RTensor tC_rSFD; - GTensor tC_gSFD; - STensor sAmaxs; - CoordGTensor tC_cSFD; - ThrResidue residue_tC_cSFD; - Params const* params_ptr; - ElementCompute norm_constant; - ElementCompute norm_constant_scaled_down; - TileCoordMN tile_coord_mn; - int thread_idx; - static constexpr int NumCollaboratingThreads = decltype(size(TiledCopy_{}))::value; - static_assert(NumCollaboratingThreads % NumThreadsPerWarpGroup == 0); - static constexpr int NumCollaboratingWarpGroups = NumCollaboratingThreads / NumThreadsPerWarpGroup; - static_assert(NumCollaboratingWarpGroups == 2, - "SM120 epilogue currently only supports two warp groups collaborating."); - static_assert(SFVecSize == 16 || SFVecSize == 32 || SFVecSize == 64, "SF vector size must be either 16, 32 or 64."); - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, - int epi_v, - int epi_m, - int epi_n, - Array const& frg_input) { - return frg_input; - } - - template - CUTLASS_DEVICE void - reduce(SmemTensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - /* - Accumulator fragments are distributed across threads/quads in different warps. For column major, the - reduction happens along M dimension. For SFVector = 32, we have: - - 8 elements 8 elements 8 elements 8 elements - + <----------------------><----------------------><----------------------><----------------------> - | Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 - | Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 - | ... ... ... ... - 1 | Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 - 6 | Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 - | Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 - | ... ... ... ... - + Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 - | Warp 1 Quad 0 Warp 5 Quad 0 Warp 1 Quad 0 Warp 5 Quad 0 - | Warp 1 Quad 1 Warp 5 Quad 1 Warp 1 Quad 1 Warp 5 Quad 1 - 1 | ... ... ... ... - 6 | Warp 1 Quad 7 Warp 5 Quad 7 Warp 1 Quad 7 Warp 5 Quad 7 - | Warp 1 Quad 0 Warp 5 Quad 0 Warp 1 Quad 0 Warp 5 Quad 0 - | Warp 1 Quad 1 Warp 5 Quad 1 Warp 1 Quad 1 Warp 5 Quad 1 - | ... ... ... ... - | Warp 1 Quad 7 Warp 5 Quad 7 Warp 1 Quad 7 Warp 5 Quad 7 - - - - In this case, colum-wise scale factors are cooperatively reduced across 8 threads from 2 warps. - Each column first computes its own, local absolute maximum and then shares this with the - corresponding threads in the other warp. In this case, a reduction through shared memory is needed. - - For SFVector = 64, the reduction happens inside 4 warps: warp 0/1/2/3 and warp 4/5/6/7. - */ - - // Accumulator fragments consist of two elements from two different columns of a 16x8 MMA output - static constexpr int RowsPerThreadAccFrag = 2; - static constexpr int ColsPerThreadAccFrag = 2; - static_assert(FragmentSize == (ColsPerThreadAccFrag * RowsPerThreadAccFrag)); - - static constexpr int NumThreadsPerCol = NumThreadsPerWarp / NumThreadsPerQuad; - constexpr int WarpsPerSF = SFVecSize / NumThreadsPerCol / ColsPerThreadAccFrag; - static_assert(WarpsPerSF == 1 || WarpsPerSF == 2 || WarpsPerSF == 4, "Only one, two or four warps are allowed in reduction."); - - auto warp_idx = thread_idx / NumThreadsPerWarp; - auto thread_idx_in_warp = thread_idx % NumThreadsPerWarp; - - cutlass::maximum_absolute_value_reduction amax_op; - cutlass::multiplies mul; - - auto synchronize = [&] () { - // When WarpsPerSF equals 1, data processing is inside warp, there is no needs to have the sync. - static constexpr bool NoSyncNeeded = (WarpsPerSF == 1); - if(NoSyncNeeded) - return; - cutlass::arch::NamedBarrier::sync(NumCollaboratingThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - }; - - CUTLASS_PRAGMA_UNROLL - for(int mma_in_epi = 0; mma_in_epi < size<1>(tC_rSFD)*size<2>(tC_rSFD); ++mma_in_epi) { - - CUTLASS_PRAGMA_UNROLL - for (int sf_id = 0; sf_id < ColsPerThreadAccFrag; ++sf_id) { - - // - // Compute amax for this scale factor - // - ElementCompute amax{0}; - - // Compute amax among vals owned by this thread for this vector - auto acc_frg = visit_results(mma_in_epi); - amax = amax_op(amax, acc_frg[sf_id]); - amax = amax_op(amax, acc_frg[sf_id + ColsPerThreadAccFrag]); - - // At this point, each thread has computed the amax of the values that it owns for this SF vector. - // We now need to compute the amax across threads. Because SM120 narrow-precision MMAs have 16x8 output - // size with a quad owning two rows, we know that 8 threads in one column will own all of the 16 elements - // to be reduced via amax. Therefore, we can use warp shuffle intrinsics among threads to compute the amax. - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < NumThreadsPerCol; ++i) { - auto amax_other = __shfl_xor_sync(0xffffffff, amax, (i * NumThreadsPerQuad)); - amax = amax_op(amax, amax_other); - } - - // At this point, all threads in the quad have the amax for the elements of the accumulator owned by its - // threads that should be used in computing the amax for this SF. - if (thread_idx_in_warp < NumThreadsPerQuad && WarpsPerSF != 1) { - sAmaxs(thread_idx_in_warp, warp_idx) = amax; - } - - synchronize(); - - // Get the amax broadcasted by the warp with which we share. - // For cooperative kernels, when scale factor vector size is 32 (WarpsPerSF equals 2), - // warp 0 shares with 1, warp2 shares with 2, etc. - // When vector size is 64 (WarpsPerSF equals 4), warp 0 shares with 1/2/3, and 4 shares with 5/6/7. - // When vector size is 16, no needs to swap between warps. - if constexpr (2 == WarpsPerSF) { - auto amax_other = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 1); - amax = amax_op(amax, amax_other); - } - else if constexpr (4 == WarpsPerSF) { - auto amax_other1 = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 1); - auto amax_other2 = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 2); - auto amax_other3 = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 3); - amax = amax_op(amax, amax_other1); - amax_other2 = amax_op(amax_other2, amax_other3); - amax = amax_op(amax, amax_other2); - } - synchronize(); - - ElementCompute pvscale = mul(amax, norm_constant_scaled_down); - UnderlyingElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); - filter(tC_rSFD)(sf_id + mma_in_epi*ColsPerThreadAccFrag) = qpvscale; - - // - // Apply the scale factor to the output - // - ElementCompute qpvscale_rcp = [&]() { - if constexpr (cute::is_same_v) { - // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); - return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); - } - else { - // UE4M3: Do the rcp in fp32 data type. - auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); - return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); - } - }(); - - ElementCompute acc_scale = mul(norm_constant, qpvscale_rcp); - acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); - - // Compute quantized output values - visit_results(mma_in_epi)[sf_id ] = mul(acc_frg[sf_id ], acc_scale); - visit_results(mma_in_epi)[sf_id + ColsPerThreadAccFrag] = mul(acc_frg[sf_id + ColsPerThreadAccFrag], acc_scale); - } // end for sf_id - } // end for mma_in_epi - - // Since scale factors are computed cooperatively across two or four warps, we only need one thread from the - // cooperating column threads group to write out the data. - bool write_sf = (thread_idx_in_warp < NumThreadsPerQuad); - if constexpr (2 == WarpsPerSF) { - // Output warp {0, 2, 4, 6}. - write_sf &= ((warp_idx & 0x1) == 0); - } - else if constexpr (4 == WarpsPerSF) { - // Output warp {0, 4}. - write_sf &= ((warp_idx & 0x3) == 0); - } - else if constexpr (1 == WarpsPerSF) { - // Output warp {0, 1, ..., 7}. Keep write_sf as is. - } - - if (write_sf && elem_less(tC_cSFD(_0{}, _0{}, _0{}, epi_m, epi_n), residue_tC_cSFD)) { - copy_aligned(tC_rSFD, tC_gSFD(_, _, _, _0{}, _0{}, get<0>(tile_coord_mn) + epi_m, get<1>(tile_coord_mn) + epi_n)); - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; - UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; - // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group - if constexpr (!cute::is_same_v) { - ptr_scale_factor = params_ptr->ptr_scale_factor[l]; - l = 0; - } - else { - ptr_scale_factor = params_ptr->ptr_scale_factor; - } - - static_assert(size<0>(EpilogueTile{}) && ((size<0>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), - "Epilogue Tile N should be pow of 2"); - - auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), - Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); - - Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_, _,l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) - Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) - - auto tile_coord_mn = make_coord(m * size<0>(epi_tile_mn), n * size<1>(epi_tile_mn)); - - // Fetch and compute these during initialization - Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); - ElementCompute norm_constant = mNormConst(_0{},_0{},l); - ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); - ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); - ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); - - Tensor sAmaxs = make_tensor( - make_smem_ptr(smem_aux), - make_layout(make_shape(Int{}, Int{})) - ); - - return ConsumerStoreCallbacks( - cute::move(tCrSFD), - cute::move(tCgSFD), - cute::move(sAmaxs), - args.tCcD, - args.residue_tCcD, - params_ptr, - tile_coord_mn, - norm_constant, - norm_constant_scaled_down, - args.thread_idx, - args.tiled_copy); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp deleted file mode 100644 index 95e8208686ead6606040ee280023a7f5b879b07b..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ /dev/null @@ -1,2792 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Fusion callbacks specializations for the sm90 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/fusion/callbacks.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" - -#include "cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -using Sm90EVT = Sm90TreeVisitor; - -// D = alpha * acc -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledAcc, - CtaTileShapeMNK, - EpilogueTile -> : Sm90EVT, - Sm90ScalarBroadcast>, - Sm90AccFetch - > { - using Impl = - Sm90EVT, - Sm90ScalarBroadcast>, - Sm90AccFetch - >; - using Operation = fusion::ScaledAcc; - - struct Arguments { - // Give a name and flat ordering to the fusion callback args - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - - // Conversion to the args expected by the visitor implementation - // to_underlying_arguments will implicitly call this - operator typename Impl::Arguments() const { - return - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }; // end binary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C -template< - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinearCombination = - Sm90EVT, // beta * C + (alpha * acc) - Sm90ScalarBroadcast>, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc - Sm90ScalarBroadcast>, // alpha - Sm90AccFetch // acc - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinearCombination, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinearCombination; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C, where beta and alpha can be vectors for each batch -template< - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinearCombinationPtrArray = - Sm90EVT, // beta * C + (alpha * acc) - Sm90ScalarBroadcastPtrArray>, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc - Sm90ScalarBroadcastPtrArray>, // alpha - Sm90AccFetch // acc - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - fusion::LinearCombination, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinearCombinationPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm90LinearCombinationPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinearCombination; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C) -template< - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombEltAct = - Sm90EVT, // activation(beta * C + (alpha * acc)) - Sm90LinearCombination // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombEltAct, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombEltAct { - - using Impl = Sm90LinCombEltAct::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombEltAct; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op: activation(beta * C + (alpha * acc)) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args: activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C), where beta and alpha can be vectors for each batch -template< - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombEltActPtrArray = - Sm90EVT, // activation(beta * C + (alpha * acc)) - Sm90LinearCombinationPtrArray // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - fusion::LinCombEltAct, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombEltActPtrArray { - - using Impl = Sm90LinCombEltActPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombEltAct; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op: activation(beta * C + (alpha * acc)) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args: activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C + per-row bias -template< - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerRowBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast>, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast>, // alpha - Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerRowBias, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> { - using Impl = Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; - using Operation = fusion::LinCombPerRowBias< - ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C + per-column bias -template< - int StagesC, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerColBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast>, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast>, // alpha - Sm90AccFetch, // acc - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerColBias, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombPerColBias< - StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> { - using Impl = Sm90LinCombPerColBias< - StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; - using Operation = fusion::LinCombPerColBias< - ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per-row bias) -template< - class CtaTileShapeMNK, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerRowBiasEltAct = - Sm90EVT, - Sm90LinCombPerRowBias - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - using Operation = - fusion::LinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per-column bias) -template< - int StagesC, - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerColBiasEltAct = - Sm90EVT, - Sm90LinCombPerColBias - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90LinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - using Operation = - fusion::LinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per-row bias) -// Aux = alpha * acc + beta * C + per-row bias) -template< - class CtaTileShapeMNK, - class EpilogueTile, - int Stages, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerRowBiasEltActAux = - Sm90EVT, - Sm90EVT, - Sm90LinCombPerRowBias - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentAux, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpR2S -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltActAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpR2S -> : Sm90LinCombPerRowBiasEltActAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90LinCombPerRowBiasEltActAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - using Operation = - fusion::LinCombPerRowBiasEltActAux< - GmemLayoutTagAux, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux* aux_ptr = nullptr; - StrideAux dAux = {}; - - operator typename Impl::Arguments() const { - return - { // unary op : activation(store(beta * C + (alpha * acc + bias))) - { // unary op : store(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, dAux} // unary args : store - }, // end unary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// D = activation(alpha * acc + beta * C + per_col bias) -// Aux = alpha * acc + beta * C + per_col bias) -template< - int StagesC, - class CtaTileShapeMNK, - class EpilogueTile, - int Stages, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerColBiasEltActAux = - Sm90EVT, - Sm90EVT, - Sm90LinCombPerColBias - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentAux, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpR2S -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerColBiasEltActAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpR2S -> : Sm90LinCombPerColBiasEltActAux< - StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90LinCombPerColBiasEltActAux< - StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - using Operation = - fusion::LinCombPerColBiasEltActAux< - GmemLayoutTagAux, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux* aux_ptr = nullptr; - StrideAux dAux = {}; - - operator typename Impl::Arguments() const { - return - { // unary op : activation(store(beta * C + (alpha * acc + bias))) - { // unary op : store(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, dAux} // unary args : store - }, // end unary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = per-row alpha * acc + per-row beta * C + per-row bias -template< - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - int AlignmentScalar = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90PerRowLinCombPerRowBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // beta, dynamic scalar/vector broadcast - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast - Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias - > - >; - -// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) -template< - class CtaTileShapeMNK, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - int AlignmentScalar = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90PerRowLinCombPerRowBiasEltAct = - Sm90EVT, - Sm90PerRowLinCombPerRowBias - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - int AlignmentScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::PerRowLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90PerRowLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - > { - - using Impl = - Sm90PerRowLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - using Operation = - fusion::PerRowLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - - struct Arguments { - using StrideAlpha = Stride; - using StrideBeta = Stride; - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - StrideAlpha dAlpha = {bool(1), _0{}, 0}; - StrideBeta dBeta = {bool(1), _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {beta_ptr, beta, dBeta}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {alpha_ptr, alpha, dAlpha}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = per-col alpha * acc + per-col beta * C + per-column bias -template< - int StagesC, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - int AlignmentScalar = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90PerColLinCombPerColBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, dynamic scalar/vector broadcast - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast - Sm90AccFetch, // acc - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias - > - >; - -// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) -template< - int StagesC, - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - int AlignmentScalar = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90PerColLinCombPerColBiasEltAct = - Sm90EVT, - Sm90PerColLinCombPerColBias - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - int AlignmentScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::PerColLinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90PerColLinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - > { - - using Impl = - Sm90PerColLinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - using Operation = - fusion::PerColLinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,bool,int64_t>; - using StrideBeta = Stride<_0,bool,int64_t>; - StrideAlpha dAlpha = {_0{}, bool(1), 0}; - StrideBeta dBeta = {_0{}, bool(1), 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {beta_ptr, beta, dBeta}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {alpha_ptr, alpha, dAlpha}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C -template< - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - int AlignmentScalar = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90PerColResAddPerColBiasEltAct = - Sm90EVT, // beta * C + activation(alpha * acc + bias) - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, dynamic scalar/vector broadcast - Sm90SrcFetch, // C - Sm90EVT, // activation(alpha * acc + bias) - Sm90EVT, // alpha * acc + bias - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast - Sm90AccFetch, // acc - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias - > - > - >; - - template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - int AlignmentScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::PerColResAddPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90PerColResAddPerColBiasEltAct< - CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - > { - - using Impl = - Sm90PerColResAddPerColBiasEltAct< - CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - using Operation = - fusion::PerColResAddPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,bool,int64_t>; - using StrideBeta = Stride<_0,bool,int64_t>; - StrideAlpha dAlpha = {_0{}, bool(1), 0}; - StrideBeta dBeta = {_0{}, bool(1), 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + activation(alpha * acc + bias) - {beta_ptr, beta, dBeta}, // leaf args : beta - {}, // leaf args : C - { // unary op : activation(alpha * acc + bias) - { // ternary op : alpha * acc + bias - {alpha_ptr, alpha, dAlpha}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; - -// We only apply the scaling factor if output is fp8 -template -struct ScaleOutOp { template using Op = cutlass::first; }; -template <> -struct ScaleOutOp { template using Op = cutlass::multiplies; }; -template <> -struct ScaleOutOp { template using Op = cutlass::multiplies; }; - -template -using amax = cutlass::maximum_absolute_value_reduction; // propogate nans - -}; // end namespace detail - -// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias -template< - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerRowBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, 2>, // scale_c * beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha - Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias - > - >; - -// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias -// if D is fp8 -// D = scale_d * activation(Z) -// else -// D = activation(Z) -template< - class CtaTileShapeMNK, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerRowBiasEltAct = - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // activation(Z) - // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias - Sm90ScaledLinCombPerRowBias - >, - Sm90ScalarBroadcast // scale_d - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90ScaledLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90ScaledLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - using Operation = - fusion::ScaledLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - ElementScalar scale_a = ElementScalar(1); - ElementScalar scale_b = ElementScalar(1); - ElementScalar scale_c = ElementScalar(1); - ElementScalar scale_d = ElementScalar(1); - ElementScalar const* scale_a_ptr = nullptr; - ElementScalar const* scale_b_ptr = nullptr; - ElementScalar const* scale_c_ptr = nullptr; - ElementScalar const* scale_d_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d - { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {{scale_d}, - {scale_d_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias -template< - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerColBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, 2>, // scale_c * beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha - Sm90AccFetch, // acc - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias - > - >; - -// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias -// if D is fp8 -// D = scale_d * activation(Z) -// else -// D = activation(Z) -template< - class CtaTileShapeMNK, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerColBiasEltAct = - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // activation(Z) - // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias - Sm90ScaledLinCombPerColBias - >, - Sm90ScalarBroadcast // scale_d - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledLinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90ScaledLinCombPerColBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90ScaledLinCombPerColBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - using Operation = - fusion::ScaledLinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - ElementScalar scale_a = ElementScalar(1); - ElementScalar scale_b = ElementScalar(1); - ElementScalar scale_c = ElementScalar(1); - ElementScalar scale_d = ElementScalar(1); - ElementScalar const* scale_a_ptr = nullptr; - ElementScalar const* scale_b_ptr = nullptr; - ElementScalar const* scale_c_ptr = nullptr; - ElementScalar const* scale_d_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d - { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {{scale_d}, - {scale_d_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias -// if D is fp8 -// amax_d = max(abs(elements in activation(Z))) -// D = scale_d * activation(Z) -// else -// D = activation(Z) -// if Aux is fp8 -// amax_aux = max(abs(elements in Z)) -// Aux = scale_aux * Z -// else -// Aux = Z - -// fp8 aux specialization -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8 = - Sm90SplitTreeVisitor< - // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerRowBias, - // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // amax_d - Sm90EVT, // activation(Z) - Sm90SplitTreeFetch // Z - > - >, - Sm90ScalarBroadcast // scale_d - >, - // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) - Sm90EVT, // store(Aux) - Sm90EVT, // Z * scale_aux - Sm90EVT, // amax_aux - Sm90SplitTreeFetch // Z - >, - Sm90ScalarBroadcast // scale_aux - > - > - >; - -// non-fp8 aux specialization -// lets us use some EVT specializations such as relu + uint1b_t aux -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 = - // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // amax_d - Sm90EVT, // activation(Z) - Sm90EVT, // Aux = Z - // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerRowBias - > - > - >, - Sm90ScalarBroadcast // scale_d - >; - -// dispatcher -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerRowBiasEltActAmaxAux = conditional_t, - Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8< - CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle - >, - Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8< - CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > ->; - - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementAmax, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentAux, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpR2S -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledLinCombPerRowBiasEltActAmaxAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpR2S -> : Sm90ScaledLinCombPerRowBiasEltActAmaxAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, - SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90ScaledLinCombPerRowBiasEltActAmaxAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, - SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - using Operation = - fusion::ScaledLinCombPerRowBiasEltActAmaxAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - ElementScalar scale_a = ElementScalar(1); - ElementScalar scale_b = ElementScalar(1); - ElementScalar scale_c = ElementScalar(1); - ElementScalar scale_d = ElementScalar(1); - ElementScalar const* scale_a_ptr = nullptr; - ElementScalar const* scale_b_ptr = nullptr; - ElementScalar const* scale_c_ptr = nullptr; - ElementScalar const* scale_d_ptr = nullptr; - - ElementScalar scale_aux = ElementScalar(1); - ElementScalar const* scale_aux_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - ElementAmax* amax_D_ptr = nullptr; - ElementAmax* amax_aux_ptr = nullptr; - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux* aux_ptr = nullptr; - StrideAux dAux = {}; - - operator typename Impl::Arguments() const { - // Only compute amax_d if D is fp8 - ElementAmax* amax_D_ptr_ = nullptr; - if constexpr (detail::is_fp8_v) { - amax_D_ptr_ = amax_D_ptr; - } - - // Aux is fp8 -> DAG arguments - if constexpr (detail::is_fp8_v) { - typename Impl::Arguments args; - // always use structured binding to unpack DAG args since it may or may not be a tuple - auto& [Z_args, aux_args, D_args] = args; - - Z_args = - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha ,{_0{}, _0{}, 0}, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }; // end ternary op - - D_args = - { // binary op : activation(Z) * scale_d or activation(Z) - { // unary op : reduce(activation(Z)) - { // unary op : activation(Z) - {}, // leaf args : Z - activation // unary args : activation - }, // end unary op - {amax_D_ptr_} // unary args : reduce - }, // end unary op - {{scale_d}, - {scale_d_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - - aux_args = - { // unary op : store(Aux) - { // binary op : Z * scale_d or Z - { // unary op : reduce(Z) - {}, // leaf args : Z - {amax_aux_ptr} // unary args : reduce - }, // end unary op - {{scale_aux}, - {scale_aux_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies - }, // end binary op - {aux_ptr, dAux} // unary args : store - }; // end unary op - - return args; - } - - // Aux is not fp8 -> Tree arguments - else { - return - { // binary op : activation(Z) * scale_d or activation(Z) - { // unary op : reduce(activation(Z)) - { // unary op : activation(Z) - { // unary op : store(Z) - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias - }, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, dAux} // unary args : store - }, // end unary op - activation // unary args : activation - }, // end unary op - {amax_D_ptr_} // unary args : reduce - }, // end unary op - {{scale_d},{scale_d_ptr}}, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - } - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias -// if D is fp8 -// amax_d = max(abs(elements in activation(Z))) -// D = scale_d * activation(Z) -// else -// D = activation(Z) -// if Aux is fp8 -// amax_aux = max(abs(elements in Z)) -// Aux = scale_aux * Z -// else -// Aux = Z - -// fp8 aux specialization -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8 = - Sm90SplitTreeVisitor< - // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias - Sm90ScaledLinCombPerColBias, - // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // amax_d - Sm90EVT, // activation(Z) - Sm90SplitTreeFetch // Z - > - >, - Sm90ScalarBroadcast // scale_d - >, - // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) - Sm90EVT, // store(Aux) - Sm90EVT, // Z * scale_aux - Sm90EVT, // amax_aux - Sm90SplitTreeFetch // Z - >, - Sm90ScalarBroadcast // scale_aux - > - > - >; - -// non-fp8 aux specialization -// lets us use some EVT specializations such as relu + uint1b_t aux -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8 = - // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // amax_d - Sm90EVT, // activation(Z) - Sm90EVT, // Aux = Z - // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerColBias - > - > - >, - Sm90ScalarBroadcast // scale_d - >; - -// dispatcher -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerColBiasEltActAmaxAux = conditional_t, - Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8< - CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle - >, - Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8< - CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > ->; - - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementAmax, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentAux, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpR2S -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledLinCombPerColBiasEltActAmaxAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpR2S -> : Sm90ScaledLinCombPerColBiasEltActAmaxAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, - SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90ScaledLinCombPerColBiasEltActAmaxAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, - SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - using Operation = - fusion::ScaledLinCombPerColBiasEltActAmaxAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - ElementScalar scale_a = ElementScalar(1); - ElementScalar scale_b = ElementScalar(1); - ElementScalar scale_c = ElementScalar(1); - ElementScalar scale_d = ElementScalar(1); - ElementScalar const* scale_a_ptr = nullptr; - ElementScalar const* scale_b_ptr = nullptr; - ElementScalar const* scale_c_ptr = nullptr; - ElementScalar const* scale_d_ptr = nullptr; - - ElementScalar scale_aux = ElementScalar(1); - ElementScalar const* scale_aux_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - ElementAmax* amax_D_ptr = nullptr; - ElementAmax* amax_aux_ptr = nullptr; - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux* aux_ptr = nullptr; - StrideAux dAux = {}; - - operator typename Impl::Arguments() const { - // Only compute amax_d if D is fp8 - ElementAmax* amax_D_ptr_ = nullptr; - if constexpr (detail::is_fp8_v) { - amax_D_ptr_ = amax_D_ptr; - } - - // Aux is fp8 -> DAG arguments - if constexpr (detail::is_fp8_v) { - typename Impl::Arguments args; - // always use structured binding to unpack DAG args since it may or may not be a tuple - auto& [Z_args, aux_args, D_args] = args; - - Z_args = - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }; // end ternary op - - D_args = - { // binary op : activation(Z) * scale_d or activation(Z) - { // unary op : reduce(activation(Z)) - { // unary op : activation(Z) - {}, // leaf args : Z - activation // unary args : activation - }, // end unary op - {amax_D_ptr_} // unary args : reduce - }, // end unary op - {{scale_d}, - {scale_d_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - - aux_args = - { // unary op : store(Aux) - { // binary op : Z * scale_d or Z - { // unary op : reduce(Z) - {}, // leaf args : Z - {amax_aux_ptr} // unary args : reduce - }, // end unary op - {{scale_aux}, - {scale_aux_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies - }, // end binary op - {aux_ptr, dAux} // unary args : store - }; // end unary op - - return args; - } - - // Aux is not fp8 -> Tree arguments - else { - return - { // binary op : activation(Z) * scale_d or activation(Z) - { // unary op : reduce(activation(Z)) - { // unary op : activation(Z) - { // unary op : store(Z) - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias - }, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, dAux} // unary args : store - }, // end unary op - activation // unary args : activation - }, // end unary op - {amax_D_ptr_} // unary args : reduce - }, // end unary op - {{scale_d},{scale_d_ptr}}, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - } - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - class CtaTileShapeMNK, - class EpilogueTile, - int Stages, - class StrideAux, - class SmemLayoutAtom, - class CopyOpS2R, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombDeEltAct = - Sm90EVT, // activation(beta * C + (alpha * acc), aux) - Sm90LinearCombination, // beta * C + (alpha * acc) - Sm90AuxLoad // aux - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementSource, - class ElementScalar, - int AlignmentAux, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpS2R -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombDeEltAct< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpS2R -> : Sm90LinCombDeEltAct< - CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle - > { - - using Impl = - Sm90LinCombDeEltAct< - CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle - >; - using Operation = - fusion::LinCombDeEltAct< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux const* aux_ptr = nullptr; - StrideAux dAux = {}; - - operator typename Impl::Arguments() const { - return - { // binary op : activation(beta * C + (alpha * acc), aux) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, ElementAux(0), dAux}, // leaf args : aux - activation // binary args : activation - }; // end binary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - class CtaTileShapeMNK, - class EpilogueTile, - int Stages, - class StrideAux, - class SmemLayoutAtom, - class CopyOpS2R, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombDeEltActDePerRowBias = - Sm90EVT, // Identity for final conversion - Sm90EVT, AlignmentBias>, - Sm90LinCombDeEltAct - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentAux, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpS2R -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombDeEltActDePerRowBias< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpS2R -> : Sm90LinCombDeEltActDePerRowBias< - CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90LinCombDeEltActDePerRowBias< - CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - using Operation = - fusion::LinCombDeEltActDePerRowBias< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux const* aux_ptr = nullptr; - StrideAux dAux = {}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias* dbias_ptr = nullptr; - StrideBias dDbias = {}; - - operator typename Impl::Arguments() const { - return - { // unary op : identity/convert - { // unary op : reduce(activation(beta * C + (alpha * acc), aux)) - { // binary op : activation(beta * C + (alpha * acc), aux) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, ElementAux(0), dAux}, // leaf args : aux - activation // binary args : activation - }, // end binary op - {dbias_ptr, ElementCompute(0), dDbias} // unary args : reduce - }, // end unary op - {} // unary args : identity/convert - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = softmax(top_k(alpha * acc + beta * C)) -template< - int TopK, - int FragmentSize, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombTopKSoftmaxCol = - Sm90EVT, // softmax(top_k(beta * C + (alpha * acc))) - Sm90LinearCombination // beta * C + (alpha * acc) - >; - -template < - int TopK, - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombTopKSoftmaxCol, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombTopKSoftmaxCol { - - using Impl = Sm90LinCombTopKSoftmaxCol::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombTopKSoftmaxCol; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - operator typename Impl::Arguments() const { - return - { // unary op: activation(beta * C + (alpha * acc)) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - {} // unary args: activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Grouped Wgrad Conv -template< - class GroupsPerTile, - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinearCombinationGroupedWgrad = - Sm90EVT, // beta * C + (alpha * acc) - Sm90ScalarBroadcast>, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc - Sm90ScalarBroadcast>, // alpha - Sm90AccFetchGroupedWgrad // acc - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class GroupsPerTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinearCombinationGroupedWgrad, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinearCombinationGroupedWgrad::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm90LinearCombinationGroupedWgrad::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinearCombinationGroupedWgrad; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - //ElementScalar groups = ElementScalar(1); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { -template > -struct get_element_aux { - using type = void; -}; - -template -struct get_element_aux> { - using type = typename FusionOpOrCallbacks::ElementAux; -}; - -template -struct get_element_aux, cute::void_t<>> { - using type = typename get_element_aux::type; -}; - -template -struct get_element_aux, cute::void_t::Operation>> { - private: - using Operation = typename FusionCallbacks::Operation; - public: - using type = typename get_element_aux::type; -}; -} // namespace cutlass:epilogue::fusion::detail - -template -using get_element_aux_t = typename detail::get_element_aux::type; - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp deleted file mode 100644 index ae63a7675c12dc4329374815da4d081a6bd885ee..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ /dev/null @@ -1,842 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree compute operations for the sm90 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/detail/helper_macros.hpp" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// N-nary Elementwise Compute Operation -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// The template argument provided for ComputeFn must be able to accept -// exactly one template parameter. In Standard C++, it's OK for -// ComputeFn to have other template parameters, as long as those have -// defaults. For example, the following struct Foo would work. -// -// template -// struct Foo { -// CUTLASS_HOST_DEVICE auto operator() (A a, B b); -// }; -// -// However, some compilers, such as Clang, require that the argument -// take _exactly_ one template parameter. This is nonstandard C++ -// behavior. One work-around for this case is to create a subclass -// with exactly one template parameter, and then use that subclass as -// the template argument. -// -// template -// struct FooHomogeneous : public Foo {}; -// -template< - template class ComputeFn, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - class = void -> -struct Sm90Compute { -private: - using EmptyArguments = typename Sm90VisitorImpl<>::Arguments; - - template - struct ComputeArguments { - using type = EmptyArguments; - }; - - // partial specialization for compute fns that define an Arguments member, e.g. activation hyperparameters - template - struct ComputeArguments> { - using type = typename Fn::Arguments; - }; - -public: - struct SharedStorage { }; - - using Arguments = typename ComputeArguments>::type; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const&, Arguments const& args, void*) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const&, Arguments const&) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_HOST_DEVICE - Sm90Compute() - : params() {} - - CUTLASS_HOST_DEVICE - Sm90Compute(Params const& params, SharedStorage const& shared_storage) - : params(params) {} - - Params const params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(Params const& params) - : params(params) {} - - Params const& params; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const&... frg_inputs) { - return transform_apply(cute::make_tuple(frg_inputs...), - [&] (auto&& frg_input) CUTLASS_LAMBDA_FUNC_INLINE { - using ElementInput = typename cute::remove_cvref_t::Element; - using ConvertInput = NumericArrayConverter; - ConvertInput convert_input{}; - - return convert_input(frg_input); - }, - [&] (auto&&... cvt_frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { - using ComputeOutput = ComputeFn>; - ComputeOutput compute_output{}; - - if constexpr (cute::is_same_v) { - using ElementComputeOutput = - typename cute::remove_cvref_t::Element; - using ConvertOutput = NumericArrayConverter; - ConvertOutput convert_output{}; - return convert_output(compute_output(cvt_frg_inputs...)); - } - else { - using ElementComputeOutput = - typename cute::remove_cvref_t::Element; - using ConvertOutput = NumericArrayConverter; - ConvertOutput convert_output{}; - return convert_output(compute_output(cvt_frg_inputs..., params)); - } - } - ); - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks(params); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Performance Optimized Specializations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// beta * C + Z -template < - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - class InputScaleOp, // beta - class ElementSource, // C - class InputAddOp // Z -> -struct Sm90TreeVisitor< - Sm90Compute().is_zero())>>, - InputScaleOp, - Sm90SrcFetch, - InputAddOp -> : Sm90VisitorImpl< - InputScaleOp, - Sm90SrcFetch, - InputAddOp, - Sm90Compute - > -{ - using Impl = - Sm90VisitorImpl< - InputScaleOp, - Sm90SrcFetch, - InputAddOp, - Sm90Compute - >; - using Params = typename Impl::Params; - using SharedStorage = typename Impl::SharedStorage; - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor() {} - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor( - Params const& params, - SharedStorage const& shared_storage) - : Impl(params, shared_storage) {} - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - auto const& scale_op = get<0>(Impl::ops); - auto const& added_op = get<2>(Impl::ops); - if constexpr (detail::IsScalarBroadcast::value && not is_void_v) { - return (get<2>(scale_op.params_ptr->dScalar[0]) != 0 && scale_op.params_ptr->scalar_ptrs[0] != nullptr) || - is_C_load_needed() || - added_op.is_producer_load_needed(); - } - else { - return is_C_load_needed() || added_op.is_producer_load_needed(); - } - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - auto const& scale_op = get<0>(Impl::ops); - auto const& src_op = get<1>(Impl::ops); - auto const& added_op = get<2>(Impl::ops); - return (not scale_op.is_zero() && src_op.is_C_load_needed()) || added_op.is_C_load_needed(); - } - - template - struct ConsumerStoreCallbacks : CallbacksImpl { - CUTLASS_DEVICE - ConsumerStoreCallbacks(bool is_C_load_needed, CallbacksImpl&& impl) - : is_C_load_needed(is_C_load_needed), CallbacksImpl(cute::forward(impl)) { } - - bool is_C_load_needed; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_added = get<2>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); - - using ElementZ = typename decltype(frg_added)::Element; - using ConvertZ = NumericArrayConverter; - using ConvertI = NumericArrayConverter; - ConvertZ convert_Z{}; - ConvertI convert_I{}; - - Array frg_I = convert_Z(frg_added); - - if constexpr (!is_void_v) { - Array frg_scalar = get<0>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); - Array frg_source = get<1>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); - - using ElementX = typename decltype(frg_scalar)::Element; - using ElementY = typename decltype(frg_source)::Element; - using ConvertX = NumericArrayConverter; - using ConvertY = NumericArrayConverter; - using ComputeI = multiply_add>; - ConvertX convert_X{}; - ConvertY convert_Y{}; - ComputeI compute_I{}; - - frg_I = compute_I(convert_X(frg_scalar), convert_Y(frg_source), frg_I); - } - - return convert_I(frg_I); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_tuple = Impl::template get_consumer_store_callbacks(args); - bool is_C_load_needed = this->is_C_load_needed(); - if (not is_C_load_needed) { - cute::clear(args.tCrC); - } - return ConsumerStoreCallbacks( - is_C_load_needed, std::move(callbacks_tuple)); - } -}; - -// ReLU with aux bit tensor dReLU/dZ -// Aux(i) = Z(i) >= 0 ? 1 : 0 -namespace detail { -// Placeholder node so we can retain standard EVT structure -template -struct Sm90ReLUAuxStore : Sm90VisitorImpl<> { - struct SharedStorage {}; - - struct Arguments { - cutlass::uint1b_t* ptr_aux = nullptr; - StrideMNL dAux = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90ReLUAuxStore() { } - - CUTLASS_HOST_DEVICE - Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage) { } -}; -} // namespace detail - -// Specialization on the generic compute+aux EVT -template < - // Compute node - template class Activation, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - // Aux node - int Stages, - class EpilogueTile, - class StrideMNL, - class SmemLayoutAtom, - class CopyOpR2S, - int Alignment, - bool EnableNullptr, - // Input node - class InputOp -> -struct Sm90TreeVisitor< - Sm90Compute, cutlass::epilogue::thread::ReLu> || - cute::is_same_v, cutlass::epilogue::thread::Clamp> || - cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU> >>, - Sm90TreeVisitor< - Sm90AuxStore< - Stages, - EpilogueTile, - cutlass::uint1b_t, - RoundStyle, - StrideMNL, - SmemLayoutAtom, - CopyOpR2S, - Alignment, - EnableNullptr - >, - InputOp - > -> : Sm90VisitorImpl< - Sm90VisitorImpl< - InputOp, - detail::Sm90ReLUAuxStore - >, - Sm90Compute - > -{ - using Impl = - Sm90VisitorImpl< - Sm90VisitorImpl< - InputOp, - detail::Sm90ReLUAuxStore - >, - Sm90Compute - >; - using Params = typename Impl::Params; - using SharedStorage = typename Impl::SharedStorage; - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor() {} - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor(Params const& params_, SharedStorage const& shared_storage) - : params(params_), Impl(params_, shared_storage) {} - - Params const& params; - - template - struct ConsumerStoreCallbacks : CallbacksImpl { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rAux, - GTensor&& tC_gAux, - CTensor tC_cAux, - ThrResidue residue_tC_cAux, - Params const& params, - CallbacksImpl&& impl) - : tC_rAux(cute::forward(tC_rAux)), - tC_gAux(cute::forward(tC_gAux)), - tC_cAux(tC_cAux), - residue_tC_cAux(residue_tC_cAux), - params(params), - CallbacksImpl(cute::forward(impl)) {} - - RTensor tC_rAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ThrResidue residue_tC_cAux; - Params const& params; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - // Unpack callbacks + params - auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; - auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; - auto const& [params_input_aux, params_compute] = params; - auto const& [params_input, params_aux] = params_input_aux; - - // Visit the input node - Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n); - - // Compute activation + aux - using ElementInput = typename decltype(frg_input)::Element; - using ConvertInput = NumericArrayConverter; - using ConvertAux = PackPredicates; - using ComputeOutput = Activation; - using ConvertOutput = NumericArrayConverter; - ConvertInput convert_input{}; - ComputeOutput relu{}; - ConvertAux convert_aux{}; - ConvertOutput convert_output{}; - - Array frg_compute = convert_input(frg_input); - bool frg_aux[FragmentSize]; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - ElementCompute pre_relu = frg_compute[i]; - if constexpr (cute::is_same_v, cutlass::epilogue::thread::Clamp> || - cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU>) { - frg_compute[i] = relu(frg_compute[i], params_compute); - } - else { - frg_compute[i] = relu(frg_compute[i]); - } - if constexpr (cute::is_same_v) { - uint32_t aux; - asm volatile("set.equ.u32.f32 %0, %1, %2;\n" : "=r"(aux) : "f"(frg_compute[i]), "f"(pre_relu)); // NaN outputs 1 in Aux - frg_aux[i] = static_cast(aux); - } else if constexpr (cute::is_same_v) { - uint32_t aux; - cutlass::half_t compute = frg_compute[i]; - asm volatile("set.equ.u32.f16 %0, %1, %2;\n" : "=r"(aux) : "h"(compute.raw()), "h"(pre_relu.raw())); // NaN outputs 1 in Aux - frg_aux[i] = static_cast(aux); - } else { - frg_aux[i] = frg_compute[i] == pre_relu; - } - } - - static_assert(FragmentSize % 8 == 0, "Predicate vector must be byte-aligned"); - Tensor tC_rAux_frg = recast(coalesce(tC_rAux(_,_,_,epi_m,epi_n))); // (EPI_V) - tC_rAux_frg(epi_v) = convert_aux(frg_aux); - - return convert_output(frg_compute); - } - - CUTLASS_DEVICE void - end() { - // Unpack callbacks + params - auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; - auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; - auto const& [params_input_aux, params_compute] = params; - auto const& [params_input, params_aux] = params_input_aux; - - // Visit the input node - callbacks_input.end(); - - // Nullptr is no-op - if constexpr (EnableNullptr) { - if (params_aux.ptr_aux == nullptr) { - return; - } - } - - // Compute vectorization - constexpr auto MCL = decltype(max_common_layout(tC_rAux, tC_gAux)){}; - constexpr int V = cute::min(Alignment, size(MCL)); - // Copy vectorizes into byte-aligned stores - if constexpr (V > 1 && V % 8 == 0) { - using VecType = uint_bit_t; - Tensor tC_rAux_vec = recast(tC_rAux); - Tensor tC_gAux_vec = recast(tC_gAux); - Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); - Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); - copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec); - } - // sub-byte vectorization, must serialize threads - else { - // Assumes no inter-warp sharing of bytes (most copy layouts should satisfy this) - int lane_idx = canonical_lane_idx(); - Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); - CUTLASS_PRAGMA_NO_UNROLL - for (int i = 0; i < NumThreadsPerWarp; ++i) { - if (lane_idx == i) { - copy_if(tC_pAux, tC_rAux, tC_gAux); - } - __syncwarp(); - } - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - // Unpack params - auto const& [params_input_aux, params_compute] = params; - auto const& [params_input, params_aux] = params_input_aux; - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - gmem_ptr ptr_aux = make_gmem_ptr(params_aux.ptr_aux); - Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params_aux.dAux)); // (M,N,L) - Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - - Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - gAux, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tC_rAux = make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - auto callbacks_impl = Impl::template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks( - cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_tCcD, params, cute::move(callbacks_impl)); - } -}; - -// Aux load for uint1b_t -template < - int Stages, - class EpilogueTile, - class StrideMNL, - class SmemLayoutAtom, - class CopyOpS2R, - int Alignment, - bool EnableNullptr -> -struct Sm90AuxLoad< - Stages, - EpilogueTile, - cutlass::uint1b_t, - StrideMNL, - SmemLayoutAtom, - CopyOpS2R, - Alignment, - EnableNullptr -> { - static_assert(Alignment % 128 == 0, "sub-16B alignment not supported yet"); - - struct SharedStorage {}; - - struct Arguments { - cutlass::uint1b_t const* ptr_aux = nullptr; - cutlass::uint1b_t null_default = cutlass::uint1b_t(0); - StrideMNL dAux = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad() { } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad(Params const& params, SharedStorage const&) - : params(params) { } - - Params const params; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, CTensor tC_cAux_, ThrResidue residue_tC_cAux_, Params const& params_) - : tC_rAux(cute::forward(tC_rAux_)), - tC_gAux(cute::forward(tC_gAux_)), - tC_cAux(tC_cAux_), - residue_tC_cAux(residue_tC_cAux_), - params(params_) {} - - RTensor tC_rAux; // (CPY,CPY_M,CPY_N,{EPI_M,EPI_N}) - GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ThrResidue residue_tC_cAux; - Params const& params; - - CUTLASS_DEVICE void - begin() { - if constexpr (decltype(cute::rank(tC_rAux))::value == 5) { - if constexpr (EnableNullptr) { - if (params.ptr_aux == nullptr) { - return; - } - } - - constexpr auto MCL = decltype(max_common_layout(tC_rAux, tC_gAux)){}; - constexpr int V = cute::min(Alignment, size(MCL)); - if constexpr (V > 1) { - using VecType = uint_bit_t; - Tensor tC_gAux_vec = recast(tC_gAux); - Tensor tC_rAux_vec = recast(tC_rAux); - Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); - Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); - copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec); - } - else { - Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); - copy_if(tC_pAux, tC_gAux, tC_rAux); - } - } - } - - CUTLASS_DEVICE void - begin_loop(int epi_m, int epi_n) { - if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { - if constexpr (EnableNullptr) { - if (params.ptr_aux == nullptr) { - return; - } - } - - Tensor tC_pAux = cute::lazy::transform(tC_cAux(_,_,_,epi_m,epi_n), [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); - copy_if(tC_pAux, tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); - } - } - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - using ElementRegister = typename remove_cvref_t::value_type; - if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { - return recast>(coalesce(tC_rAux))(epi_v); - } - else { - return recast>(coalesce(tC_rAux(_,_,_,epi_m,epi_n)))(epi_v); - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - gmem_ptr ptr_aux = make_gmem_ptr(params.ptr_aux); - Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L) - Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - - Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - gAux, args.epi_tile, args.tiled_copy, args.thread_idx); - - // If byte-unaligned vectorization, store in registers as uint32_t to reduce redundant pack+unpack instruction sequences - constexpr int V = decltype(max_common_vector(tC_gAux.layout(), make_layout(tC_gAux.shape())))::value; - Tensor tC_rAux = [&] () CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (V % 8 != 0) { - return make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) - } else { - return make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - } - }(); - - if constexpr (EnableNullptr) { - if (params.ptr_aux == nullptr) { - fill(tC_rAux, params.null_default); - } - } - - return ConsumerStoreCallbacks( - cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_tCcD, params); - } -}; - -// dReLU specialization -template< - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle -> -struct Sm90Compute< - cutlass::epilogue::thread::dReLU, - ElementOutput, - ElementCompute, - RoundStyle -> : Sm90VisitorImpl<> { - - using Sm90VisitorImpl<>::Sm90VisitorImpl; - - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input, - Array const& frg_aux) { - using ConvertInput = NumericArrayConverter; - using ComputeOutput = cutlass::epilogue::thread::dReLU>; - using ConvertOutput = NumericArrayConverter; - ConvertInput convert_input{}; - ComputeOutput compute_output{}; - ConvertOutput convert_output{}; - - return convert_output(compute_output(convert_input(frg_input), frg_aux)); // don't convert frg_aux for dReLU - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp deleted file mode 100644 index 535d8b082d44ff796fe2efc4e1531b4a3dc2674c..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ /dev/null @@ -1,1492 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree load operations for the sm90 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/arch/barrier.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/detail/helper_macros.hpp" - -#include "cute/tensor.hpp" -#include "sm90_visitor_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Elementwise Fetch Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// returns accumulator -struct Sm90AccFetch : Sm90VisitorImpl<> { - - using Sm90VisitorImpl<>::Sm90VisitorImpl; - - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - return frg_acc; - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks{}; - } -}; - -// Split tree visitor fetches intermediate results from temporary accumulators -using Sm90SplitTreeFetch = Sm90AccFetch; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// returns C -template -struct Sm90SrcFetch : Sm90VisitorImpl<> { - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return is_C_load_needed(); - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return not is_void_v; - } - - CUTLASS_DEVICE bool - is_zero() const { - return is_void_v; - } - - using Sm90VisitorImpl<>::Sm90VisitorImpl; - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(SrcTensor const& tCrC) - : tCrC(tCrC) {} - - SrcTensor const& tCrC; // (CPY,CPY_M,CPY_N) - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - return recast>(tCrC)(epi_v); - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - // register type may differ from logical type so we can't assert matching types here - return ConsumerStoreCallbacks(args.tCrC); - } -}; - -// returns accumulator in Grouped Conv Wgrad -template -struct Sm90AccFetchGroupedWgrad : Sm90VisitorImpl<> { - - using Sm90VisitorImpl<>::Sm90VisitorImpl; - using GroupsPerTile = GroupsPerTile_; - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(int32_t thread_idx) - : thread_idx(thread_idx) { } - - int32_t thread_idx; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - - Array frg_acc_rst; - int warp_id = thread_idx / 32; - - // In Grouped Wgrad, only diagonal block data is valid and the others is wrong and useless. - // One block size is C/G x C/G. Note that C/G = Tile_N / GroupsPerTile. - // Copy diagonal block ACC into the first block Col which is the output tensor size Tile_M * C/G. - // Then we can store the valid output tensor tile directly. - if constexpr ( cute::is_same_v ) { - frg_acc_rst = frg_acc; - } - else if constexpr ( cute::is_same_v ) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 16; i++) { - frg_acc_rst[i] = frg_acc[i + warp_id / 2 * 16]; - } - } - else if constexpr ( cute::is_same_v ) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 8; i++) { - frg_acc_rst[i] = frg_acc[i + warp_id * 8]; - } - } - else if constexpr ( cute::is_same_v ) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; i++) { - frg_acc_rst[i] = frg_acc[i + warp_id * 8 + i / 2 * 4]; - } - } - - return frg_acc_rst; - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks(args.thread_idx); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Elementwise Load Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int Stages, - class EpilogueTile, - class Element, - class StrideMNL, - class SmemLayoutAtom, - class CopyOpS2R, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true // Fallback scalar broadcast for nullptr params -> -struct Sm90AuxLoad { - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - - constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); - // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) - using SmemShapeTma = decltype(make_shape( - max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), - max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); - using SmemLayoutTma = decltype(tile_to_shape( - SmemLayoutAtom{}, SmemShapeTma{}, - cute::conditional_t, Step<_1,_2>>{} )); - using SmemLayout = decltype(tile_to_shape( - SmemLayoutTma{}, - make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), - cute::conditional_t, Step<_1,_2,_3>>{} )); - using CopyOpG2S = - SM90_TMA_LOAD - ; - - struct SharedStorage { - alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) - array_aligned smem_aux; - }; - - struct Arguments { - Element const* ptr_aux = nullptr; - Element null_default = Element(0); - StrideMNL dAux = {}; - }; - - struct Params { - using TMA_Aux = decltype(make_tma_copy( - CopyOpG2S{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideMNL{}, int32_t(0)), append<3>(StrideMNL{}, _0{})), - take<0,2>(SmemLayoutTma{}))); - TMA_Aux tma_load_aux; - Element null_default = Element(0); - bool use_default = false; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto M_AUX = - size(M) - ; - Tensor tensor_aux = make_tensor(make_gmem_ptr(args.ptr_aux), make_layout(make_shape(M_AUX,N,L), append<3>(args.dAux, _0{}))); - typename Params::TMA_Aux tma_load_aux = make_tma_copy(CopyOpG2S{}, tensor_aux, take<0,2>(SmemLayoutTma{})); - - bool use_default = false; - if constexpr (EnableNullptr) { - use_default = args.ptr_aux == nullptr; - } - - return Params{tma_load_aux, args.null_default, use_default}; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad() { } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms), - smem_aux(const_cast(shared_storage.smem_aux.data())) { } - - Params const* params_ptr; - Element* smem_aux; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return true; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_zero() const { - return (params_ptr->use_default && params_ptr->null_default == Element(0)); - } - - template - struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { - CUTLASS_DEVICE - ProducerLoadCallbacks(GTensor&& bGS_gAux, STensor&& bGS_sAux, Params const* params_ptr) - : bGS_gAux(cute::forward(bGS_gAux)), - bGS_sAux(cute::forward(bGS_sAux)), - params_ptr(params_ptr) {} - - GTensor bGS_gAux; // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) - STensor bGS_sAux; // (TMA,TMA_M,TMA_N,PIPE) - Params const* params_ptr; - - CUTLASS_DEVICE void - step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { - if constexpr (EnableNullptr) { - if (params_ptr->use_default) { - return; - } - } - - if (issue_tma_load) { - // Increment the expected transaction bytes of the current stage's mbarrier by the subtile's byte-size - constexpr uint32_t copy_bytes = size(take<0,2>(SmemLayout{})) * sizeof_bits_v / 8; - cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); - // Issue the TMA load - constexpr uint16_t mcast_mask = 0; - int load_pipe_index = load_iteration % Stages; - copy(params_ptr->tma_load_aux.with(*full_mbarrier_ptr, mcast_mask), - bGS_gAux(_,_,_,epi_m,epi_n), bGS_sAux(_,_,_,load_pipe_index)); - } - } - }; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - auto coord_shape = - make_coord(m, n, l) - ; - Tensor mAux_mn = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor mAux = coalesce(mAux_mn, take<0,2>(args.tile_shape_mnk)); - Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), coord_shape); // (CTA_M,CTA_N) - - Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) - - ThrCopy thrblk_g2s = params_ptr->tma_load_aux.get_slice(_0{}); - Tensor bGS_gAux = thrblk_g2s.partition_S(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) - Tensor bGS_sAux = thrblk_g2s.partition_D(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) - - return ProducerLoadCallbacks( - cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr); - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tC_rAux, TiledS2R tiled_s2r, STensorS2R&& tSR_sAux, Params const* params_ptr) - : tC_rAux(cute::forward(tC_rAux)), - tiled_s2r(tiled_s2r), - tSR_sAux(cute::forward(tSR_sAux)), - params_ptr(params_ptr) { } - - TiledS2R tiled_s2r; - RTensor tC_rAux; // (CPY,CPY_M,CPY_N) - STensorS2R tSR_sAux; // (S2R,S2R_M,S2R_N,PIPE) - Params const* params_ptr; - - CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { - if constexpr (EnableNullptr) { - if (params_ptr->use_default) { - fill(tC_rAux, params_ptr->null_default); - return; - } - } - - using RLayoutS2R = decltype(cute::layout(TiledS2R{}.get_slice(0).retile_S(RTensor{}))); - Tensor tSR_rAux = make_tensor(tC_rAux.data(), RLayoutS2R{}); // (S2R,S2R_M,S2R_N) - - int load_pipe_index = load_iteration % Stages; - copy(tiled_s2r, tSR_sAux(_,_,_,load_pipe_index), tSR_rAux); - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) - - return tC_rAux_frg(epi_v); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - - Tensor mAux_mn = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor mAux = coalesce(mAux_mn, take<0,2>(args.tile_shape_mnk)); - Tensor tC_gAux = sm90_partition_for_epilogue(mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) - - auto tiled_s2r = conditional_return( - make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), - make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) - ); - Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) - auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE) - - return ConsumerStoreCallbacks( - cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr); - } -}; - -template < - class Element, - class EpilogueTile, // Unused - class LayoutOrStrideMNL, - class SmemLayoutAtom, // Unused - class CopyOpS2R, // Unused - int Alignment, - bool EnableNullptr -> -struct Sm90AuxLoad< - 0, EpilogueTile, Element, LayoutOrStrideMNL, - SmemLayoutAtom, CopyOpS2R, Alignment, EnableNullptr -> { - using ElementAux = Element; - using StrideMNL = cutlass::gemm::TagToStrideC_t; - - struct SharedStorage { }; - - struct Arguments { - Element const* ptr_aux = nullptr; - Element null_default = Element(0); - StrideMNL dAux = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad() { } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { } - - Params const* params_ptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template< - class GTensorG2R, - class RTensor, - class CTensorG2R, - class ProblemShapeMNL - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(GTensorG2R&& tC_gAux, - RTensor&& tC_rAux, - CTensorG2R&& tC_cAux, - ProblemShapeMNL problem_shape_mnl, - Params const* params_ptr) - : tC_gAux(cute::forward(tC_gAux)), - tC_rAux(cute::forward(tC_rAux)), - tC_cAux(cute::forward(tC_cAux)), - problem_shape_mnl(problem_shape_mnl), - params_ptr(params_ptr) {} - - GTensorG2R tC_gAux; - RTensor tC_rAux; - CTensorG2R tC_cAux; - ProblemShapeMNL problem_shape_mnl; - Params const* params_ptr; - - CUTLASS_DEVICE void - begin_loop(int epi_m, int epi_n) { - if constexpr (EnableNullptr) { - if (params_ptr->ptr_aux == nullptr) { - fill(tC_rAux, params_ptr->null_default); - return; - } - } - constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; - constexpr int V = cute::min(Alignment, size(MCL)); - - Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); - Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); - - Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int{}))); - Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); }); - - copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec); - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - return recast>(tC_rAux)(epi_v); - } - }; - - template < - bool ReferenceSrc, - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - - auto problem_shape_mnl = make_shape(M,N,L); - - // Gmem Tensor - Tensor mAux = make_tensor( - make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux - ); - Tensor tC_gAux = sm90_partition_for_epilogue( - mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - - // Register Tensor - Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); - - // Predication support - Tensor coordAux = make_identity_tensor(shape(mAux)); - Tensor tC_cAux = sm90_partition_for_epilogue( - coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - - return ConsumerStoreCallbacks( - cute::move(tC_gAux), - cute::move(tC_rAux), - cute::move(tC_cAux), - problem_shape_mnl, - params_ptr - ); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Broadcast Load Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Scalar broadcast -// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors -template< - class Element, - class StrideMNL_ = Stride<_0,_0,_0>, - int BroadcastCount = 1, - template class ReductionFn = multiplies -> -struct Sm90ScalarBroadcast { - using StrideMNL = StrideMNL_; - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); - - struct SharedStorage { }; - - struct Arguments { - Element scalars[BroadcastCount] = {}; - Element const* scalar_ptrs[BroadcastCount] = {}; - StrideMNL dScalar[BroadcastCount] = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter *cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - // This must be called after update_scalar is called - CUTLASS_DEVICE bool - is_zero() const { - if (get<2>(params_ptr->dScalar[0]) == 0) { - // Only 1 batch - return scalar == Element(0); - } - else { - // multiple batch - if (valid_scalar == false) { - // for stridedBatch kernel, if ptr has a valid address, we need to enable the epi_load warps. - return params_ptr->scalar_ptrs[0] == nullptr; - } - else { - // Check whether each batch is ZERO or not. - return scalar == Element(0); - } - } - } - - CUTLASS_HOST_DEVICE - Sm90ScalarBroadcast() { } - - CUTLASS_HOST_DEVICE - Sm90ScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { - // Get the scalar for non-batched broadcast - if (size<2>(params_ptr->dScalar[0]) == 0) { - update_scalar(); - } - } - - Element scalar; - bool valid_scalar = false; - Params const* params_ptr; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - // Get the scalar for batched broadcast - if (size<2>(params_ptr->dScalar[0]) != 0) { - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - } - - return EmptyProducerLoadCallbacks{}; - } - - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(Element scalar) - : scalar(scalar) {} - - Element scalar; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_scalar; - frg_scalar.fill(scalar); - - return frg_scalar; - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - // Get the scalar for batched broadcast - if (get<2>(params_ptr->dScalar[0]) != 0) { - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - } - - return ConsumerStoreCallbacks(scalar); - } - -private: - CUTLASS_DEVICE void - update_scalar(int l_coord = 0) { - valid_scalar = true; - int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); - - if (params_ptr->scalar_ptrs[0] != nullptr) { - scalar = params_ptr->scalar_ptrs[0][l_offset]; - } - else { - // batch stride is ignored for nullptr fallback - scalar = params_ptr->scalars[0]; - } - - // Do reduction over multiple broadcasts if necessary - ReductionFn reduction_fn; - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < BroadcastCount; ++i) { - if (params_ptr->scalar_ptrs[i] != nullptr) { - int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); - scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); - } - else { - // batch stride is ignored for nullptr fallback - scalar = reduction_fn(scalar, params_ptr->scalars[i]); - } - } - } - - template - CUTLASS_DEVICE void - update_scalar(cute::tuple) { - // Only support multiple L-modes with fully-broadcast scalar - scalar = params_ptr->scalars[0]; - valid_scalar = true; - } -}; - -// Scalar broadcast -// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors -template< - class Element, - class StrideMNL_ = Stride<_0,_0,_0>, - int BroadcastCount = 1, - template class ReductionFn = multiplies -> -struct Sm90ScalarBroadcastPtrArray { - using StrideMNL = StrideMNL_; - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); - - struct SharedStorage { }; - - struct Arguments { - Element scalars[BroadcastCount] = {}; - Element const* scalar_ptrs[BroadcastCount] = {}; - Element const* const* scalar_ptr_arrays[BroadcastCount] = {}; - StrideMNL dScalar[BroadcastCount] = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter *cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - // producer load is needed if Element is not void - return !cute::is_void_v; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - // This must be called after update_scalar is called - CUTLASS_DEVICE bool - is_zero() const { - return scalar == Element(0); - } - - CUTLASS_HOST_DEVICE - Sm90ScalarBroadcastPtrArray() { } - - CUTLASS_HOST_DEVICE - Sm90ScalarBroadcastPtrArray(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { - // Get the scalar for non-batched broadcast - if (size<2>(params_ptr->dScalar[0]) == 0) { - update_scalar(); - } - } - - Element scalar; - Params const* params_ptr; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - // Always refresh scalar with the current group index so per-group - // alpha/beta values (provided through pointer arrays) are loaded - // correctly even when the L-stride is zero. - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - - return EmptyProducerLoadCallbacks{}; - } - - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(Element scalar) - : scalar(scalar) {} - - Element scalar; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_scalar; - frg_scalar.fill(scalar); - - return frg_scalar; - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - - return ConsumerStoreCallbacks(scalar); - } - -private: - CUTLASS_DEVICE void - update_scalar(int l_coord = 0) { - int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); - - if (params_ptr->scalar_ptr_arrays[0] != nullptr) { - // Pointer-array variant: each entry already points to the scalar of a group. - scalar = *(params_ptr->scalar_ptr_arrays[0][l_coord]); - } - else if (params_ptr->scalar_ptrs[0] != nullptr) { - // Strided pointer variant. - scalar = params_ptr->scalar_ptrs[0][l_offset]; - } - else { - // Literal fallback. - scalar = params_ptr->scalars[0]; - } - - // Do reduction over multiple broadcasts if necessary - ReductionFn reduction_fn; - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < BroadcastCount; ++i) { - - if (params_ptr->scalar_ptr_arrays[i] != nullptr) { - scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][l_coord])); - } - else if (params_ptr->scalar_ptrs[i] != nullptr) { - int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); - scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); - } - else { - scalar = reduction_fn(scalar, params_ptr->scalars[i]); - } - } - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -[[deprecated("row broadcast only uses 0 stages")]] constexpr int -compute_row_broadcast_stages() { - return ceil_div(StagesC, size<1>(zipped_divide(make_layout(take<0,2>(CtaTileShapeMNK{})), EpilogueTile{}))) + 1; -} - -} - -// Row vector broadcast -template< - int Stages, - class CtaTileShapeMNK, - class ElementInput_, - class ElementCompute = cute::remove_pointer_t, - class StrideMNL_ = Stride<_0,_1,_0>, - int Alignment = 128 / sizeof_bits_v>, - bool EnableNullptr = true // Fallback scalar broadcast for nullptr params -> -struct Sm90RowBroadcast { - using StrideMNL = StrideMNL_; - // Get base element input type. - using ElementInput = cute::remove_pointer_t; - // Check if input is an array of pointers. - static constexpr bool IsArrayOfPointers = is_same_v; - using PtrRowType = cute::conditional_t; - - static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining"); - - static constexpr bool IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // row vector or scalar broadcast - static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast); - - struct SharedStorage { - array_aligned(CtaTileShapeMNK{})> smem; - }; - - struct Arguments { - PtrRowType ptr_row = nullptr; - ElementInput null_default = ElementInput(0); - StrideMNL dRow = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90RowBroadcast() { } - - CUTLASS_HOST_DEVICE - Sm90RowBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params), is_zero_(false), - smem(const_cast(shared_storage.smem.data())) { - auto const& [stride_M, stride_N, stride_L] = params.dRow; - // Nullptr default - if (EnableNullptr && params.ptr_row == nullptr) { - is_zero_ = params.null_default == ElementCompute(0); - } - // Dynamic non-batched scalar broadcast - else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) { - if constexpr (!IsArrayOfPointers) { - is_zero_ = params.ptr_row[0] == ElementInput(0); - } - } - } - - Params params; - bool is_zero_ = false; - ElementInput *smem = nullptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_zero() const { - return is_zero_; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, - GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, - SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, - Residue residue_cRow_, Params const& params_) - : tGS_gRow(tGS_gRow_) - , tGS_sRow(tGS_sRow_) - , tGS_cRow(tGS_cRow_) - , tiled_G2S(tiled_g2s_) - , tSR_sRow(tSR_sRow_) - , tSR_rRow(tSR_rRow_) - , residue_cRow(residue_cRow_) - , params(params_) { - } - - GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) - GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) - GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) - Tiled_G2S tiled_G2S; - - SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - Residue residue_cRow; // (m, n) - Params const& params; - - CUTLASS_DEVICE void - begin() { - bool is_nullptr = EnableNullptr && params.ptr_row == nullptr; - - Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); - Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); - Tensor tGS_cRow_flt = filter_zeros(tGS_cRow, tGS_gRow.stride()); - - for (int i = 0; i < size(tGS_gRow_flt); ++i) { - if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { - continue; // OOB of SMEM, - } - if (not is_nullptr && elem_less(tGS_cRow_flt(i), residue_cRow)) { - tGS_sRow_flt(i) = tGS_gRow_flt(i); // issue async gmem to smem load - } - else { - tGS_sRow_flt(i) = params.null_default; // fill OOB values so smem to RF load can issue without predication - } - } - } - - CUTLASS_DEVICE bool - begin_sync_needed() const { - return true; // Ensure visibility of async gmem to smem loads - } - - CUTLASS_DEVICE void - begin_loop(int epi_m, int epi_n) { - if (epi_m == 0) { // Assumes M-major subtile loop - Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); - Tensor tSR_rRow_flt = make_tensor_like(tSR_sRow_flt); - copy_aligned(tSR_sRow_flt, tSR_rRow_flt); - - constexpr int FrgSize = size(tSR_rRow_flt); - using FrgInput = Array; - using FrgCompute = Array; - using ConvertInput = NumericArrayConverter; - - Tensor tSR_rRow_input_frg = recast(coalesce(tSR_rRow_flt)); - Tensor tSR_rRow_compute_frg = recast(filter(tSR_rRow)); - ConvertInput convert_input{}; - - tSR_rRow_compute_frg(_0{}) = convert_input(tSR_rRow_input_frg(_0{})); - } - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_row; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); - } - - return frg_row; - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - using ThreadCount = decltype(size(args.tiled_copy)); - - auto layout_N = [&] () CUTLASS_LAMBDA_FUNC_INLINE { - auto shape_N = get<1>(args.problem_shape_mnkl); - if constexpr (IsDynamicBroadcast) { - auto stride_N = repeat_like(shape_N, int(0)); - if (get<1>(params.dRow) == bool(1)) { - stride_N = transform_leaf(compact_major(shape_N), - [] (auto const& stride) { return static_cast(stride); } - ); - } - return make_layout(shape_N, stride_N); - } - else { - return make_layout(shape_N); - } - }(); - - auto layout_M = make_layout(M, repeat_like(M, _0{})); - auto layout_L = make_layout(L, get<2>(params.dRow)); - ElementInput const* ptr_row = nullptr; - if constexpr(IsArrayOfPointers) { - if (!(EnableNullptr && params.ptr_row == nullptr)) { - ptr_row = params.ptr_row[l]; - } - } else { - ptr_row = params.ptr_row; - } - Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L)); - Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem), - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) - //// G2S: Gmem to Smem - auto tiled_g2s = make_tiled_copy(Copy_Atom{}, - Layout< Shape<_1, ThreadCount>, - Stride<_0, _1>>{}, - Layout<_1>{}); - auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); - Tensor tGS_gRow = thr_g2s.partition_S(gRow); - Tensor tGS_sRow = thr_g2s.partition_D(sRow); - - //// G2S: Coord - Tensor tGS_cRow = thr_g2s.partition_S(args.cD); - - //// S2R: Smem to Reg - Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) - - return ConsumerStoreCallbacks( - tGS_gRow, - tGS_sRow, - tGS_cRow, tiled_g2s, - tSR_sRow, - tSR_rRow, - args.residue_cD, - params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Column vector broadcast -template< - int Stages, - class CtaTileShapeMNK, - class ElementInput_, - class ElementCompute = cute::remove_pointer_t, - class StrideMNL_ = Stride<_1,_0,_0>, - int Alignment = 128 / sizeof_bits_v>, - bool EnableNullptr = true // Fallback scalar broadcast for nullptr params -> -struct Sm90ColBroadcast { - using StrideMNL = StrideMNL_; - // Get base element input type. - using ElementInput = cute::remove_pointer_t; - // Check if input is an array of pointers. - static constexpr bool IsArrayOfPointers = is_same_v; - using PtrColType = cute::conditional_t; - - static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining"); - - static constexpr bool IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // Column vector or scalar broadcast - static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{} || IsDynamicBroadcast); - - // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem - struct SharedStorage { }; - - struct Arguments { - PtrColType ptr_col = nullptr; - ElementInput null_default = ElementInput(0); - StrideMNL dCol = {}; - }; - - struct Params { - PtrColType ptr_col = nullptr; - ElementCompute null_default = ElementCompute(0); - StrideMNL dCol = {}; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return {args.ptr_col, ElementCompute(args.null_default), args.dCol}; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_zero() const { - return is_zero_; - } - - CUTLASS_HOST_DEVICE - Sm90ColBroadcast() { } - - CUTLASS_HOST_DEVICE - Sm90ColBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params), is_zero_(false) { - auto const& [stride_M, stride_N, stride_L] = params.dCol; - // Nullptr default - if (EnableNullptr && params.ptr_col == nullptr) { - is_zero_ = params.null_default == ElementCompute(0); - } - // Dynamic non-batched scalar broadcast - else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) { - if constexpr (!IsArrayOfPointers) { - is_zero_ = params.ptr_col[0] == ElementInput(0); - } - } - } - - Params params; - bool is_zero_; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(GTensor tCgCol_, RTensor tCrCol_, CTensor tCcCol_, ThrResidue residue_tCcCol_, Params const& params_) - : tCgCol(tCgCol_), - tCrCol(tCrCol_), - tCcCol(tCcCol_), - residue_tCcCol(residue_tCcCol_), - params(params_) { - if (EnableNullptr && params.ptr_col == nullptr) { - fill(tCrCol, params.null_default); - } - } - - GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ThrResidue residue_tCcCol; - Params const& params; - - CUTLASS_DEVICE void - begin() { - if (EnableNullptr && params.ptr_col == nullptr) { - return; - } - - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - Tensor tCgCol_flt = filter_zeros(tCgCol); - Tensor tCrCol_flt = make_tensor_like(filter_zeros(tCrCol)); - Tensor tCcCol_flt = filter_zeros(tCcCol, tCgCol.stride()); - - constexpr auto MCL = decltype(max_common_layout(tCgCol_flt, tCrCol_flt)){}; - constexpr int V = cute::min(Alignment, size(MCL)); - if constexpr (V > 1) { - using VecType = uint_bit_t>; - Tensor tCgCol_vec = recast(coalesce(tCgCol_flt)); - Tensor tCrCol_vec = recast(coalesce(tCrCol_flt)); - Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int{}))); - Tensor tCpCol_vec = cute::lazy::transform(tCcCol_vec, [&](auto const& c){ return elem_less(c, residue_tCcCol); }); - copy_if(tCpCol_vec, tCgCol_vec, tCrCol_vec); - } - else { - Tensor tCpCol_flt = cute::lazy::transform(tCcCol_flt, [&](auto const& c){ return elem_less(c, residue_tCcCol); }); - copy_if(tCpCol_flt, tCgCol_flt, tCrCol_flt); - } - - constexpr int FrgSize = size(tCrCol_flt); - using FrgInput = Array; - using FrgCompute = Array; - using ConvertInput = NumericArrayConverter; - - Tensor tCrCol_input_frg = recast(coalesce(tCrCol_flt)); - Tensor tCrCol_compute_frg = recast(filter(tCrCol)); - ConvertInput convert_input{}; - - tCrCol_compute_frg(_0{}) = convert_input(tCrCol_input_frg(_0{})); - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_col; - Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); - } - - return frg_col; - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - auto layout_M = [&] () CUTLASS_LAMBDA_FUNC_INLINE { - auto shape_M = get<0>(args.problem_shape_mnkl); - if constexpr (IsDynamicBroadcast) { - auto stride_M = repeat_like(shape_M, int(0)); - if (get<0>(params.dCol) == bool(1)) { - stride_M = transform_leaf(compact_major(shape_M), - [] (auto const& stride) { return static_cast(stride); } - ); - } - return make_layout(shape_M, stride_M); - } - else { - return make_layout(shape_M); - } - }(); - - auto layout_N = make_layout(N, repeat_like(N, _0{})); - auto layout_L = make_layout(L, get<2>(params.dCol)); - ElementInput const* ptr_col = nullptr; - if constexpr(IsArrayOfPointers) { - if (!(EnableNullptr && params.ptr_col == nullptr)) { - ptr_col = params.ptr_col[l]; - } - } else { - ptr_col = params.ptr_col; - } - Tensor mCol = make_tensor(make_gmem_ptr(ptr_col), make_layout(layout_M,layout_N,layout_L)); - Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - - Tensor mCol_static = make_tensor(make_gmem_ptr(ptr_col), make_layout(make_layout(M),layout_N,layout_L)); - Tensor tCgCol_static = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - return ConsumerStoreCallbacks(tCgCol, tCrCol, args.tCcD, args.residue_tCcD, params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Batch matrix broadcast -// Only need to redefine this if we can multicast across cluster L -template < - int Stages, - class EpilogueTile, - class Element, - class StrideMNL, - class SmemLayoutAtom, - class CopyOpS2R, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true // Fallback scalar broadcast for nullptr params -> -using Sm90MatrixBroadcast - = Sm90AuxLoad; - -namespace detail { - -template -struct IsScalarBroadcast { - static constexpr bool value = false; -}; - -template -struct IsScalarBroadcast(typename Operation::StrideMNL{})), Stride<_0,_0>>>> { - static constexpr bool value = true; -}; - -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp deleted file mode 100644 index 06ad8082e57cedf4d16aecdad8a995e838e1c93e..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ /dev/null @@ -1,1722 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/workspace.h" - -#include "cute/tensor.hpp" -#include "sm90_visitor_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Elementwise Store Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int Stages, - class EpilogueTile, - class Element, - FloatRoundStyle RoundStyle, - class StrideMNL, - class SmemLayoutAtom, - class CopyOpR2S, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true // Noop on nullptr params -> -struct Sm90AuxStore { - using ElementAux = Element; - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - - constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); - // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) - using SmemShapeTma = decltype(make_shape( - max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), - max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); - using SmemLayoutTma = decltype(tile_to_shape( - SmemLayoutAtom{}, SmemShapeTma{}, - cute::conditional_t, Step<_1,_2>>{} )); - using SmemLayout = decltype(tile_to_shape( - SmemLayoutTma{}, - make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), - cute::conditional_t, Step<_1,_2,_3>>{} )); - - struct SharedStorage { - alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) - array_aligned smem_aux; - }; - - struct Arguments { - Element* ptr_aux = nullptr; - StrideMNL dAux = {}; - }; - - struct Params { - using TMA_Aux = decltype(make_tma_copy( - SM90_TMA_STORE{}, - make_tensor(static_cast(nullptr), repeat_like(StrideMNL{}, int32_t(0)), StrideMNL{}), - SmemLayoutTma{})); - TMA_Aux tma_store_aux; - bool is_nullptr = false; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - - bool is_nullptr = false; - if constexpr (EnableNullptr) { - is_nullptr = args.ptr_aux == nullptr; - } - - typename Params::TMA_Aux tma_store_aux; - if (not is_nullptr) { - Tensor tensor_aux = make_tensor(args.ptr_aux, make_layout(make_shape(M,N,L), args.dAux)); - tma_store_aux = make_tma_copy(SM90_TMA_STORE{}, tensor_aux, SmemLayoutTma{}); - } - - return {tma_store_aux, is_nullptr}; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90AuxStore() { } - - CUTLASS_HOST_DEVICE - Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms), - smem_aux(const_cast(shared_storage.smem_aux.data())) { } - - Params const* params_ptr; - Element* smem_aux; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template < - class RTensor, - class TiledR2S, - class STensorR2S, - class STensorS2G, - class GTensorS2G - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rAux, - TiledR2S tiled_r2s, - STensorR2S&& tRS_sAux, - STensorS2G&& bSG_sAux, - GTensorS2G&& bSG_gAux, - Params const* params_ptr) - : tiled_r2s(tiled_r2s), - tC_rAux(cute::forward(tC_rAux)), - tRS_sAux(cute::forward(tRS_sAux)), - bSG_sAux(cute::forward(bSG_sAux)), - bSG_gAux(cute::forward(bSG_gAux)), - params_ptr(params_ptr) {} - - TiledR2S tiled_r2s; - RTensor tC_rAux; // (CPY,CPY_M,CPY_N) - STensorR2S tRS_sAux; // (R2S,R2S_M,R2S_N,PIPE) - STensorS2G bSG_sAux; // (S2G,S2G_M,S2G_N,PIPE) - GTensorS2G bSG_gAux; // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - Params const* params_ptr; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { - using ConvertInput = NumericArrayConverter; - ConvertInput convert_input{}; - - Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) - tC_rAux_frg(epi_v) = convert_input(frg_input); - - return frg_input; - } - - CUTLASS_DEVICE void - postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { - if constexpr (EnableNullptr) { - if (params_ptr->is_nullptr) { - return; - } - } - - using RLayoutR2S = decltype(cute::layout(TiledR2S{}.get_slice(0).retile_S(RTensor{}))); - Tensor tRS_rAux = make_tensor(tC_rAux.data(), RLayoutR2S{}); // (R2S,R2S_M,R2S_N) - - if (issue_smem_store) { - int store_pipe_index = store_iteration % Stages; - copy(tiled_r2s, tRS_rAux, tRS_sAux(_,_,_,store_pipe_index)); - } - } - - CUTLASS_DEVICE void - tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { - if constexpr (EnableNullptr) { - if (params_ptr->is_nullptr) { - return; - } - } - - if (issue_tma_store) { - // Issue the TMA store - int store_pipe_index = store_iteration % Stages; - copy(params_ptr->tma_store_aux, bSG_sAux(_,_,_,store_pipe_index), bSG_gAux(_,_,_,epi_m,epi_n)); - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - Tensor mAux = params_ptr->tma_store_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - - Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - gAux, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) - - Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) - Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - auto tiled_r2s = conditional_return( - make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), - make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) - ); - auto tRS_sAux = tiled_r2s.get_slice(args.thread_idx).partition_D(sAux_epi); // (R2S,R2S_M,R2S_N,PIPE) - - ThrCopy thrblk_s2g = params_ptr->tma_store_aux.get_slice(_0{}); - Tensor bSG_sAux = thrblk_s2g.partition_S(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) - Tensor bSG_gAux = thrblk_s2g.partition_D(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) - - return ConsumerStoreCallbacks( - cute::move(tC_rAux), - tiled_r2s, - cute::move(tRS_sAux), - cute::move(bSG_sAux), - cute::move(bSG_gAux), - params_ptr); - } -}; - -template < - class Element, - class EpilogueTile, // Unused - FloatRoundStyle RoundStyle, - class LayoutOrStrideMNL, - class SmemLayoutAtom, // Unused - class CopyOpR2S, // Unused - int Alignment, - bool EnableNullptr -> -struct Sm90AuxStore< - 0, EpilogueTile, Element, RoundStyle, LayoutOrStrideMNL, - SmemLayoutAtom, CopyOpR2S, Alignment, EnableNullptr -> { - using ElementAux = Element; - using StrideMNL = cutlass::gemm::TagToStrideC_t; - - struct SharedStorage { }; - - struct Arguments { - Element* ptr_aux = nullptr; - StrideMNL dAux = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90AuxStore() { } - - CUTLASS_HOST_DEVICE - Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { } - - Params const* params_ptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template< - class GTensorR2G, - class RTensor, - class CTensorR2G, - class ProblemShapeMNL - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - GTensorR2G&& tC_gAux, - RTensor&& tC_rAux, - CTensorR2G&& tC_cAux, - ProblemShapeMNL problem_shape_mnl, - Params const* params_ptr) - : tC_gAux(cute::forward(tC_gAux)), - tC_rAux(cute::forward(tC_rAux)), - tC_cAux(cute::forward(tC_cAux)), - problem_shape_mnl(problem_shape_mnl), - params_ptr(params_ptr) {} - - GTensorR2G tC_gAux; - RTensor tC_rAux; - CTensorR2G tC_cAux; - ProblemShapeMNL problem_shape_mnl; - Params const* params_ptr; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { - using ConvertInput = NumericArrayConverter; - ConvertInput convert_input{}; - - Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); - tC_rAux_frg(epi_v) = convert_input(frg_input); - - return frg_input; - } - - CUTLASS_DEVICE void - end_loop(int epi_m, int epi_n) { - if constexpr (EnableNullptr) { - if (params_ptr->ptr_aux == nullptr) { - return; - } - } - - constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; - constexpr int V = cute::min(Alignment, size(MCL)); - - Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); - Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); - - Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int{}))); - Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); }); - - copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec); - } - }; - - template < - bool ReferenceSrc, - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - - auto problem_shape_mnl = make_shape(M,N,L); - - // Gmem Tensor - Tensor mAux = make_tensor( - make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux - ); - Tensor tC_gAux = sm90_partition_for_epilogue( - mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - - // Register Tensor - Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); - - // Predication support - Tensor coordAux = make_identity_tensor(shape(mAux)); - Tensor tC_cAux = sm90_partition_for_epilogue( - coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - - return ConsumerStoreCallbacks( - cute::move(tC_gAux), - cute::move(tC_rAux), - cute::move(tC_cAux), - problem_shape_mnl, - params_ptr - ); - - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Reduction Store Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Scalar reduction -template < - template class RegReduceFn, - template class GmemReduceFn, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - class StrideMNL = Stride<_0,_0,_0>, - bool EnableNullptr = true // Noop on nullptr params -> -struct Sm90ScalarReduction { -private: - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); - static constexpr bool IsAtomic = is_atomic>::value; - static_assert(IsAtomic, "non-atomic scalar reduction not supported yet"); - -public: - struct SharedStorage { }; - - struct Arguments { - ElementOutput* ptr_scalar = nullptr; - ElementCompute reduction_identity = ElementCompute(0); - StrideMNL dScalar = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - #if !defined(CUTLASS_SKIP_REDUCTION_INIT) - if constexpr (IsAtomic) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar); - if (args.ptr_scalar != nullptr) { - return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter); - } - } - #endif - - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_HOST_DEVICE - Sm90ScalarReduction() { } - - CUTLASS_HOST_DEVICE - Sm90ScalarReduction(Params const& params, SharedStorage const& shared_storage) - : params(params) { } - - Params const params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - int l_coord, - CTensor tCcScalar, - ThrResidue residue_tCcScalar, - Params const& params) - : scalar(params.reduction_identity), - l_coord(l_coord), - tCcScalar(tCcScalar), - residue_tCcScalar(residue_tCcScalar), - params(params) {} - - ElementCompute scalar; - int l_coord; - CTensor tCcScalar; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ThrResidue residue_tCcScalar; - Params params; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { - if constexpr (EnableNullptr) { - if (params.ptr_scalar == nullptr) { - return frg_input; - } - } - - using ConvertInput = NumericArrayConverter; - using ReduceInput = RegReduceFn; - ConvertInput convert_input{}; - ReduceInput reduce_input{}; - - Array frg_I = convert_input(frg_input); - Tensor tCcScalar_mn = tCcScalar(_,_,_,epi_m,epi_n); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - if (elem_less(tCcScalar_mn(epi_v * FragmentSize + i), residue_tCcScalar)) { - scalar = reduce_input(scalar, frg_I[i]); - } - } - - return frg_input; - } - - CUTLASS_DEVICE void - end() { - if constexpr (EnableNullptr) { - if (params.ptr_scalar == nullptr) { - return; - } - } - - using ConvertI = NumericConverter; - using ReduceInput = GmemReduceFn; - - ConvertI convert_I{}; - ReduceInput reduce_input{}; - - ElementOutput* ptr_scalar = params.ptr_scalar + l_coord * get<2>(params.dScalar); - reduce_input(ptr_scalar, convert_I(scalar)); - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks( - get<3>(args.tile_coord_mnkl), args.tCcD, args.residue_tCcD, params); - } - -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Row vector reduction -template < - template class RegReduceFn, - template class ShuffleReduceFn, - template class GmemReduceFn, - int Stages, - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - class StrideMNL = Stride<_0,_1,_0>, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true, // Noop on nullptr params - // If this is false, ptr_row is assumed to point to a compact n-major (ceil_div(M,CTA_M), round_nearest(N,CTA_N), L) - // tensor of ElementCompute. It is the user's responsibility to reduce this to a (N, L) tensor of ElementOutput - bool FinalReduction = true, - // False means skip OOB predication if OOB inputs are known to be the reduction identity - bool VisitCheckOOB = true, - // Indicate the parameter order when calling RegReduceFn - // Seq length equals the number of RegReduceFn parameters - // No.0 represents tCrRow; No.1 and subsequent numbers sequentially represent frg_inputs in `visit` - class RegReduceSeq = cute::seq<0, 1> -> -struct Sm90RowReduction { -private: - static_assert(Stages == 0, "Smem usage not supported yet"); - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); - static constexpr bool IsAtomic = is_atomic>::value; - static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); - -public: - struct SharedStorage { }; - - struct Arguments { - void* ptr_row = nullptr; // ElementOutput* if FinalReduction, else ElementCompute* - ElementCompute reduction_identity = ElementCompute(0); - StrideMNL dRow = {}; - }; - - struct Params { - void* ptr_row = nullptr; - ElementCompute reduction_identity = ElementCompute(0); - StrideMNL dRow = {}; - ElementCompute* reduction_buffer = nullptr; - int* tile_counters = nullptr; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - ElementCompute* reduction_buffer; - int* tile_counters = nullptr; - if constexpr (IsAtomic) { - reduction_buffer = nullptr; - } - else if constexpr (FinalReduction) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); - tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); - - reduction_buffer = reinterpret_cast(workspace); - tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); - } - else { - reduction_buffer = reinterpret_cast(args.ptr_row); - } - - return { - args.ptr_row, - args.reduction_identity, - args.dRow, - reduction_buffer, - tile_counters - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - if constexpr (IsAtomic || not FinalReduction) { - return 0; - } - - size_t workspace_size = 0; - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - // Increment by size of reduction buffer - workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); - // Align and increment by size of tile counters - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - workspace_size += cute::ceil_div(size<>(N), tile_N) * sizeof(int); - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - if constexpr (IsAtomic) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow); - if (args.ptr_row != nullptr) { - return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter); - } - return Status::kSuccess; - } - else if constexpr (FinalReduction) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); - tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); - - int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); - size_t tile_counters_size = cute::ceil_div(size<>(N), tile_N) * sizeof(int); - return zero_workspace(tile_counters, tile_counters_size, stream, cuda_adapter); - } - else { - return Status::kSuccess; - } - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_HOST_DEVICE - Sm90RowReduction() { } - - CUTLASS_HOST_DEVICE - Sm90RowReduction(Params const& params, SharedStorage const& shared_storage) - : params(params) { } - - Params params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) - : args_tuple(cute::forward(args_tuple)), - params(params) {} - - ArgsTuple args_tuple; - Params const& params; - bool do_final_reduction = false; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const&... frg_inputs) { - if constexpr (EnableNullptr) { - if (params.ptr_row == nullptr) { - return cute::get<0>(cute::make_tuple(frg_inputs...)); - } - } - - auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; - Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n); - Tensor tCcRow_mn = tCcRow(_,_,_,epi_m,epi_n); - - if constexpr (VisitCheckOOB) { - using ReduceInput = RegReduceFn; - ReduceInput reduce_input{}; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - if (elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_tCcRow)) { - ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); - tCrRow_vmn = transform_apply(cute::make_tuple(frg_inputs...), - [&] (auto&& frg_input) { - return ElementCompute(frg_input[i]); - }, - [&] (auto&&... cvt_frg_inputs) { - auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn, cvt_frg_inputs...); - return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); - }); - } - } - } - else { - constexpr int RegFragSize = cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute))); - using ReduceInput = RegReduceFn>; - ReduceInput reduce_input{}; - Tensor tCrRow_mn_frg = recast>(tCrRow_mn); - - constexpr int RegFragArraySize = FragmentSize / RegFragSize; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < RegFragArraySize; ++i) { - Array& tCrRow_vmn_frg = tCrRow_mn_frg(epi_v * RegFragArraySize + i); - tCrRow_vmn_frg = transform_apply(cute::make_tuple(frg_inputs...), - [&] (auto&& frg_input) { - using ElementInput = typename cute::remove_cvref_t::Element; - using ConvertInput = NumericArrayConverter; - using RegFragArr = Array, RegFragArraySize>; - ConvertInput convert_input{}; - return convert_input(reinterpret_cast(frg_input)[i]); - }, - [&] (auto&&... cvt_frg_inputs) { - auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn_frg, cvt_frg_inputs...); - return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); - }); - } - } - return cute::get<0>(cute::make_tuple(frg_inputs...)); - } - - template - CUTLASS_DEVICE void - reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - if (not is_last_iteration) { - return; - } - - auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; - auto [m, n, k, l] = tile_coord_mnkl; - constexpr bool ReferenceSrc = decltype(ref_src)::value; - if constexpr (EnableNullptr) { - if (params.ptr_row == nullptr) { - return; - } - } - - // fully OOB CTA in partially OOB cluster - if (not elem_less(cRow(_0{},_0{}), residue_cRow)) { - return; - } - - int lane_m = get<0>(lane_mn); - [[maybe_unused]] bool is_reduced_lane = lane_m == 0; - - // - // 1. Warp shuffle reduction - // - using FragmentShuffle = Array; - Tensor tCrRow_frg = recast(filter(tCrRow)); - using ReduceShuffle = ShuffleReduceFn; - ReduceShuffle reduce_shuffle{}; - - auto FrgSizePerLaneM = size(tCrRow_frg) / size<0>(lane_layout_MN); - constexpr bool SwapShuffle = FrgSizePerLaneM > 0; - - // - // Swap Shuffle - // - // The normal way to reduction among threads: - // use shuffle to let *** the first half of threads *** have *** whole data *** from the second half of threads. - // After each step of reduction, a half of threads won't work in the following steps. - // That is, as the reduction progresses, the efficiency of shuffle & reduction instructions gradually change from 1/2, 1/4 to 1/32 (the worst case). - // - // To overcome this shortcoming, for a NxN matrix to be reduced among N threads as a 1XN vectors, - // we use swap & shuffle aiming to let *** each half of threads *** have *** a half of data *** from the other half of threads. - // After reduction, each half of threads should deal with a (N/2)x(N/2) sub-matrix independently in the following step. - // We can recursively do this until the problem size is 1. - // - if constexpr (SwapShuffle) { // for a NxN matrix to be reduced among N threads as a 1XN vectors - Tensor tCrRow_frg_ = logical_divide(tCrRow_frg, FrgSizePerLaneM); // (FrgSizePerLaneM, M) - CUTLASS_PRAGMA_UNROLL - for (int m = size<1>(tCrRow_frg_) / 2; m > 0; m /= 2) { - CUTLASS_PRAGMA_UNROLL - for (int r = 0; r < m; ++r) { - auto frg_A = tCrRow_frg_(_,r); - auto frg_B = tCrRow_frg_(_,r + m); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < size(frg_A); ++v) { - // Step1: swap - if (not (lane_m & m)) { // the first half of threads swap fragments from the first half of data to the second - cutlass::swap(frg_A(v), frg_B(v)); - } - - // Step2: shuffle - uint64_t frg_shfl = reinterpret_cast(frg_A(v)); - // each half of threads get a half of data from the other half of threads - frg_shfl = __shfl_xor_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(m, _0{})); - - // Step3: reduction - frg_A(v) = reduce_shuffle(frg_B(v), reinterpret_cast(frg_shfl)); - } - } - } - } - else { - CUTLASS_PRAGMA_UNROLL - for (int reduction_rows = size<0>(lane_layout_MN) / 2; reduction_rows > 0; reduction_rows /= 2) { - CUTLASS_PRAGMA_UNROLL - for (int frg_idx = 0; frg_idx < size(tCrRow_frg); ++frg_idx) { - uint64_t frg_shfl = reinterpret_cast(tCrRow_frg(frg_idx)); - frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(reduction_rows, _0{})); - tCrRow_frg(frg_idx) = reduce_shuffle(tCrRow_frg(frg_idx), reinterpret_cast(frg_shfl)); - } - } - } - - // - // 2. Atomic reduction - // - if constexpr (IsAtomic) { - // Filter so we don't issue redunant copies over stride-0 modes - Tensor tCrRow_flt = filter_zeros(tCrRow); - Tensor tCcRow_flt = make_tensor(tCcRow.data(), make_layout(tCrRow_flt.shape(), tCcRow.stride())); - auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); - - Tensor tCgRow = sm90_partition_for_epilogue(gRow_l(_,_,l), epi_tile, tiled_copy, thread_idx); - Tensor tCgRow_flt = filter_zeros(tCgRow); - // NOTE: atomic reduction is performed in the output type - using ConvertOutput = NumericConverter; - using ReduceOutput = GmemReduceFn; - ConvertOutput convert_output{}; - ReduceOutput reduce_output{}; - - if constexpr (SwapShuffle) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FltFrgSizePerLaneM; ++i) { - int idx = lane_m * FltFrgSizePerLaneM + i; - // Only care about OOB for N mode - if (get<1>(tCcRow_flt(idx)) < get<1>(residue_tCcRow)) { - reduce_output(&tCgRow_flt(idx), convert_output(tCrRow_flt(i))); - } - } - } - else { - if (is_reduced_lane) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrRow_flt); ++i) { - if (elem_less(tCcRow_flt(i), residue_tCcRow)) { - reduce_output(&tCgRow_flt(i), convert_output(tCrRow_flt(i))); - } - } - } - } - sync_fn(); - } - - // - // 2. One warp in M, skip threadblock smem reduction - // - else if constexpr (decltype(size<0>(warp_layout_MN))::value <= 1) { - // Dump warp reduction to gmem workspace - using ElementGmem = cute::conditional_t; - Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_ml(_,_,m,l), epi_tile, tiled_copy, thread_idx); - - if constexpr (SwapShuffle) { - Tensor tCrRow_flt = filter(tCrRow); - Tensor tCgBuf_flt = recast(filter(tCgBuf)); - auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); - Tensor tCgBuf_flt_ = logical_divide(tCgBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) - Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) - copy_aligned(tCrRow_flt_(_,_0{}), tCgBuf_flt_(_,lane_m)); - } - else { - if (is_reduced_lane) { - copy_aligned(tCrRow, recast(tCgBuf)); - } - } - sync_fn(); - } - - // - // 2. Multiple warps in M, do threadblock smem reduction - // - else { - Tensor sBuf = make_tensor(make_smem_ptr(raw_pointer_cast(smem_buffer.data())), sBuf_layout); - static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= - decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), - "smem reduction buffer not large enough, use a larger epilogue tile"); - sync_fn(); - - // Dump warp reduction to smem workspace - Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<0>(warp_mn)), epi_tile, tiled_copy, thread_idx); - - if constexpr (SwapShuffle) { - Tensor tCrRow_flt = filter(tCrRow); - Tensor tCsBuf_flt = filter(tCsBuf); - auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); - Tensor tCsBuf_flt_ = logical_divide(tCsBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) - Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) - copy_aligned(tCrRow_flt_(_,_0{}), tCsBuf_flt_(_,lane_m)); - } - else { - if (is_reduced_lane) { - copy_aligned(tCrRow, tCsBuf); - } - } - sync_fn(); - - constexpr int SmemFragSize = cute::max(size_t{1}, sizeof(uint32_t) / sizeof(ElementCompute)); - using FragmentSmem = Array; - using VectorSmem = uint_bit_t>; - using ReduceSmem = GmemReduceFn; - ReduceSmem reduce_smem{}; - - Tensor sBuf_frg = recast(filter_zeros(sBuf)); - Tensor sBuf_vec = recast(filter_zeros(sBuf)); - constexpr int FragsPerRow = decltype(size<1>(sBuf_frg))::value; - - constexpr int RowNum = decltype(size<0>(warp_layout_MN))::value; - using FragmentSmemArray = Array; - - // Do the threadblock smem reduction - using VectorGmem = cute::conditional_t; - Tensor gBuf_vec = recast(filter(gBuf_ml(_,_,m,l))); - CUTLASS_PRAGMA_UNROLL - for (int frg_idx = thread_idx; frg_idx < FragsPerRow; frg_idx += size(tiled_copy)) { - FragmentSmemArray frg_smem; - - CUTLASS_PRAGMA_UNROLL - for (int reduction_rows = 0; reduction_rows < RowNum; ++reduction_rows) { - int FragsCurrRows = reduction_rows * FragsPerRow; - frg_smem[reduction_rows] = sBuf_frg(FragsCurrRows + frg_idx); - } - - CUTLASS_PRAGMA_UNROLL - for (int reduction_rows = RowNum / 2; reduction_rows > 0; reduction_rows /= 2) { - CUTLASS_PRAGMA_UNROLL - for (int row_idx = 0; row_idx < reduction_rows; ++row_idx) { - frg_smem[row_idx] = reduce_smem(frg_smem[row_idx], frg_smem[row_idx + reduction_rows]); - } - } - gBuf_vec(frg_idx) = reinterpret_cast(frg_smem[0]); - } - sync_fn(); - } - - // - // 3. Increment atomic counters to signal final gmem reduction - // - if constexpr (not IsAtomic && FinalReduction) { - // Ensure gmem writes are visible to other threads before incrementing counter - __threadfence(); - sync_fn(); - // Collective thread 0 increments atomic tile counter and copies value to smem - int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); - if (thread_idx == 0) { - *prev_tile_count = atomicAdd(¶ms.tile_counters[n], 1); - } - sync_fn(); - // Broadcast tile count to other threads in CTA and determine final reduction status - do_final_reduction = *prev_tile_count == size<2>(gBuf_ml) * size<3>(gBuf_ml) - 1; - sync_fn(); - } - } - - CUTLASS_DEVICE void - end() { - // - // 4. Do final gmem reduction if necessary - // - if constexpr (not IsAtomic && FinalReduction) { - if (not do_final_reduction) { - return; - } - - auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; - - using ReduceOutput = GmemReduceFn; - using ConvertOutput = NumericConverter; - ReduceOutput reduce_output{}; - ConvertOutput convert_output{}; - - // Reduction over batches - if (size<2>(stride(gRow_l)) == 0) { - CUTLASS_PRAGMA_NO_UNROLL - for (int n = thread_idx; n < size<1>(gBuf_ml); n += size(tiled_copy)) { - Tensor tRgBuf_ml = gBuf_ml(_0{},n,_,_); - ElementCompute output = tRgBuf_ml(_0{}); - CUTLASS_PRAGMA_NO_UNROLL - for (int ml = 1; ml < size(tRgBuf_ml); ++ml) { - output = reduce_output(output, tRgBuf_ml(ml)); - } - if (elem_less(cRow(_0{},n), residue_cRow)) { - gRow_l(_0{},n,_0{}) = convert_output(output); - } - } - } - // No reduction over batches - else { - CUTLASS_PRAGMA_NO_UNROLL - for (int n = thread_idx; n < size<1>(gBuf_ml); n += size(tiled_copy)) { - bool do_store = elem_less(cRow(_0{},n), residue_cRow); - CUTLASS_PRAGMA_NO_UNROLL - for (int l = 0; l < size<3>(gBuf_ml); ++l) { - Tensor tRgBuf_m = gBuf_ml(_0{},n,_,l); - ElementCompute output = tRgBuf_m(_0{}); - CUTLASS_PRAGMA_NO_UNROLL - for (int m = 1; m < size(tRgBuf_m); ++m) { - output = reduce_output(output, tRgBuf_m(m)); - } - if (do_store) { - gRow_l(_0{},n,l) = convert_output(output); - } - } - } - } - - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - Layout ref_layout_MN = [&] () { - auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); - if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } - else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } - }(); // tile_mn -> tv_idx - - // Get the MN layout + coord of lanes to determine shuffle reduction iterations - using _W = Int; - Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx - Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx - Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx - Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn - int lane_idx = canonical_lane_idx(); - auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); - - // Get the MN layout + coord of warps to determine smem reduction iterations - Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx - Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx - Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx - Layout inv_warp_layout_MN = right_inverse(warp_layout_MN); // warp_idx -> warp_mn - - int warp_idx = args.thread_idx / NumThreadsPerWarp; - auto warp_mn = idx2crd(inv_warp_layout_MN(warp_idx), shape(warp_layout_MN)); - - // Partition output gmem and register tensors - auto [tile_M, tile_N, tile_K] = args.tile_shape_mnk; - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); // (M,N,L) - Tensor gRow_l = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,_)); // (CTA_M,CTA_N,L) - Tensor tCgRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - gRow_l(_,_,l), args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrRow = make_tensor_like(tCgRow); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - fill(tCrRow, params.reduction_identity); - - // Partition gmem+smem reduction buffer tensors - Layout gBuf_layout = make_layout(take<0,2>(args.tile_shape_mnk), make_stride(_0{}, _1{})); - auto block_shape = ceil_div(make_shape(M,N,L), shape(gBuf_layout)); // (M_CNT, N_CNT, L_CNT) - - // Let the M_CNT (the num of partial reduction results) become the outer mode - Layout block_layout = make_layout(block_shape, make_stride(get<1>(block_shape), _1{}, get<0>(block_shape) * get<1>(block_shape))); - Layout mBuf_layout = blocked_product(gBuf_layout, block_layout); - Tensor mBuf = make_tensor(make_gmem_ptr(params.reduction_buffer), mBuf_layout); // (ceil_M,ceil_N,L) - Tensor gBuf_ml = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(_,n,_)); // (CTA_M,CTA_N,REST_M,L) - Layout sBuf_layout = blocked_product(gBuf_layout, // (CTA_M,CTA_N,WARPS_M) - make_layout(make_shape(_1{},_1{},size<0>(warp_layout_MN)))); - - auto args_tuple = make_tuple( - bool_constant{}, cute::move(tCrRow), args.tCcD, gRow_l, args.cD, gBuf_ml, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - args.tile_coord_mnkl, args.residue_cD, args.residue_tCcD, args.epi_tile, args.tiled_copy, args.thread_idx); - return ConsumerStoreCallbacks(cute::move(args_tuple), params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Col vector reduction -template < - template class RegReduceFn, - template class ShuffleReduceFn, - template class GmemReduceFn, - int Stages, - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - class StrideMNL = Stride<_1,_0,_0>, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true, // Noop on nullptr params - // If this is false, ptr_col is assumed to point to a compact m-major (round_nearest(M,CTA_M), ceil_div(N,CTA_N), L) - // tensor of ElementCompute. It is the user's responsibility to reduce this to a (M, L) tensor of ElementOutput - bool FinalReduction = true, - // False means skip OOB predication if OOB inputs are known to be the reduction identity - bool VisitCheckOOB = true -> -struct Sm90ColReduction { -private: - static_assert(Stages == 0, "Smem usage not supported yet"); - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{}); - static constexpr bool IsAtomic = is_atomic>::value; - static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); - -public: - struct SharedStorage { }; - - struct Arguments { - void* ptr_col = nullptr; // ElementOutput* if FinalReduction, else ElementCompute* - ElementCompute reduction_identity = ElementCompute(0); - StrideMNL dCol = {}; - }; - - struct Params { - void* ptr_col = nullptr; - ElementCompute reduction_identity = ElementCompute(0); - StrideMNL dCol = {}; - ElementCompute* reduction_buffer = nullptr; - int* tile_counters = nullptr; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - ElementCompute* reduction_buffer; - int* tile_counters = nullptr; - if constexpr (IsAtomic) { - reduction_buffer = nullptr; - } - else if constexpr (FinalReduction) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); - tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); - - reduction_buffer = reinterpret_cast(workspace); - tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); - } - else { - reduction_buffer = reinterpret_cast(args.ptr_col); - } - - return { - args.ptr_col, - args.reduction_identity, - args.dCol, - reduction_buffer, - tile_counters - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - if constexpr (IsAtomic || not FinalReduction) { - return 0; - } - - size_t workspace_size = 0; - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - - // Increment by size of reduction buffer - workspace_size += product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); - // Align and increment by size of tile counters - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - workspace_size += cute::ceil_div(M, tile_M) * sizeof(int); - - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - if constexpr (IsAtomic) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol); - if (args.ptr_col != nullptr) { - return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter); - } - return Status::kSuccess; - } - else if constexpr (FinalReduction) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); - tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); - - int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); - size_t tile_counters_size = cute::ceil_div(M, tile_M) * sizeof(int); - return zero_workspace(tile_counters, tile_counters_size, stream, cuda_adapter); - } - else { - return Status::kSuccess; - } - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_HOST_DEVICE - Sm90ColReduction() { } - - CUTLASS_HOST_DEVICE - Sm90ColReduction(Params const& params, SharedStorage const& shared_storage) - : params(params) { } - - Params params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) - : args_tuple(cute::forward(args_tuple)), - params(params) {} - - ArgsTuple args_tuple; - Params const& params; - bool do_final_reduction = false; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { - if constexpr (EnableNullptr) { - if (params.ptr_col == nullptr) { - return frg_input; - } - } - - auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; - Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); - Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); - - using ConvertInput = NumericArrayConverter; - using ReduceInput = RegReduceFn; - ConvertInput convert_input{}; - ReduceInput reduce_input{}; - - Array frg_I = convert_input(frg_input); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - if (!VisitCheckOOB || elem_less(tCcCol_mn(epi_v * FragmentSize + i), residue_tCcCol)) { - ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); - tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); - } - } - - return frg_input; - } - - template - CUTLASS_DEVICE void - reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - if (not is_last_iteration) { - return; - } - - auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; - auto [m, n, k, l] = tile_coord_mnkl; - constexpr bool ReferenceSrc = decltype(ref_src)::value; - - // Runtime nullptr is noop - if constexpr (EnableNullptr) { - if (params.ptr_col == nullptr) { - return; - } - } - - // fully OOB CTA in partially OOB cluster - if (not elem_less(cCol(_0{},_0{}), residue_cCol)) { - return; - } - - // - // 1. Warp shuffle reduction - // - using FragmentShuffle = Array; - using ReduceShuffle = ShuffleReduceFn; - ReduceShuffle reduce_shuffle{}; - Tensor tCrCol_frg = recast(filter(tCrCol)); - CUTLASS_PRAGMA_UNROLL - for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { - CUTLASS_PRAGMA_UNROLL - for (int frg_idx = 0; frg_idx < size(tCrCol_frg); ++frg_idx) { - uint64_t frg_shfl = reinterpret_cast(tCrCol_frg(frg_idx)); - frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(_0{},reduction_cols)); - tCrCol_frg(frg_idx) = reduce_shuffle(tCrCol_frg(frg_idx), reinterpret_cast(frg_shfl)); - } - } - bool is_reduced_lane = get<1>(lane_mn) == 0; - - // - // 2. Atomic reduction - // - if constexpr (IsAtomic) { - // Filter so we don't issue redunant copies over stride-0 modes - Tensor tCrCol_flt = filter_zeros(tCrCol); - Tensor tCcCol_flt = make_tensor(tCcCol.data(), make_layout(tCrCol_flt.shape(), tCcCol.stride())); - - Tensor tCgCol = sm90_partition_for_epilogue(gCol_l(_,_,l), epi_tile, tiled_copy, thread_idx); - Tensor tCgCol_flt = filter_zeros(tCgCol); - - // NOTE: atomic reduction is performed in the output type - using ConvertOutput = NumericConverter; - using ReduceOutput = GmemReduceFn; - ConvertOutput convert_output{}; - ReduceOutput reduce_output{}; - - if (is_reduced_lane) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrCol_flt); ++i) { - if (elem_less(tCcCol_flt(i), residue_tCcCol)) { - reduce_output(&tCgCol_flt(i), convert_output(tCrCol_flt(i))); - } - } - } - sync_fn(); - } - - // - // 2. One warp in N, skip threadblock smem reduction - // - else if constexpr (decltype(size<1>(warp_layout_MN))::value <= 1) { - // Dump warp reduction to gmem workspace - using ElementGmem = cute::conditional_t; - Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_nl(_,_,n,l), epi_tile, tiled_copy, thread_idx); - if (is_reduced_lane) { - copy_aligned(tCrCol, recast(tCgBuf)); - } - sync_fn(); - } - - // - // 2. Multiple warps in N, do threadblock smem reduction - // - else { - Tensor sBuf = make_tensor(make_smem_ptr(raw_pointer_cast(smem_buffer.data())), sBuf_layout); - static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= - decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), - "smem reduction buffer not large enough, use a larger epilogue tile"); - sync_fn(); - - // Dump warp reduction to smem workspace - Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<1>(warp_mn)), epi_tile, tiled_copy, thread_idx); - if (is_reduced_lane) { - copy_aligned(tCrCol, tCsBuf); - } - sync_fn(); - - constexpr int SmemFragSize = cute::max(size_t{1}, sizeof(uint32_t) / sizeof(ElementCompute)); - using FragmentSmem = Array; - using VectorSmem = uint_bit_t>; - using ReduceSmem = GmemReduceFn; - ReduceSmem reduce_smem{}; - - Tensor sBuf_frg = recast(filter_zeros(sBuf)); - Tensor sBuf_vec = recast(filter_zeros(sBuf)); - constexpr int FragsPerCol = decltype(size<0>(sBuf_frg))::value; - - // Do the threadblock smem reduction - CUTLASS_PRAGMA_UNROLL - for (int reduction_cols = size<1>(warp_layout_MN) / 2; reduction_cols > 1; reduction_cols /= 2) { - int FragsPerReduction = reduction_cols * FragsPerCol; - CUTLASS_PRAGMA_NO_UNROLL - for (int frg_idx = thread_idx; frg_idx < FragsPerReduction; frg_idx += size(tiled_copy)) { - FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerReduction)); - sBuf_vec(frg_idx) = reinterpret_cast(frg_smem); - } - sync_fn(); - } - - // Do final smem reduction and dump to gmem workspace - using VectorGmem = cute::conditional_t; - Tensor gBuf_vec = recast(filter(gBuf_nl(_,_,n,l))); - CUTLASS_PRAGMA_NO_UNROLL - for (int frg_idx = thread_idx; frg_idx < FragsPerCol; frg_idx += size(tiled_copy)) { - FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerCol)); - gBuf_vec(frg_idx) = reinterpret_cast(frg_smem); - } - sync_fn(); - } - - // - // 3. Increment atomic counters to signal final gmem reduction - // - if constexpr (not IsAtomic && FinalReduction) { - // Ensure gmem writes are visible to other threads before incrementing counter - __threadfence(); - sync_fn(); - // Collective thread 0 increments atomic tile counter and copies value to smem - int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); - if (thread_idx == 0) { - *prev_tile_count = atomicAdd(¶ms.tile_counters[m], 1); - } - sync_fn(); - // Broadcast tile count to other threads in CTA and determine final reduction status - do_final_reduction = *prev_tile_count == size<2>(gBuf_nl) * size<3>(gBuf_nl) - 1; - sync_fn(); - } - } - - CUTLASS_DEVICE void - end() { - // - // 4. Do final gmem reduction if necessary - // - if constexpr (not IsAtomic && FinalReduction) { - if (not do_final_reduction) { - return; - } - - auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; - - using ReduceOutput = GmemReduceFn; - using ConvertOutput = NumericConverter; - ReduceOutput reduce_output{}; - ConvertOutput convert_output{}; - - // Reduction over batches - if (size<2>(stride(gCol_l)) == 0) { - CUTLASS_PRAGMA_NO_UNROLL - for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { - Tensor tRgBuf_nl = gBuf_nl(m,_0{},_,_); - ElementCompute output = tRgBuf_nl(_0{}); - CUTLASS_PRAGMA_NO_UNROLL - for (int nl = 1; nl < size(tRgBuf_nl); ++nl) { - output = reduce_output(output, tRgBuf_nl(nl)); - } - if (elem_less(cCol(m,_0{}), residue_cCol)) { - gCol_l(m,_0{},_0{}) = convert_output(output); - } - } - } - // No reduction over batches - else { - CUTLASS_PRAGMA_NO_UNROLL - for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { - bool do_store = elem_less(cCol(m,_0{}), residue_cCol); - CUTLASS_PRAGMA_NO_UNROLL - for (int l = 0; l < size<3>(gBuf_nl); ++l) { - Tensor tRgBuf_n = gBuf_nl(m,_0{},_,l); - ElementCompute output = tRgBuf_n(_0{}); - CUTLASS_PRAGMA_NO_UNROLL - for (int n = 1; n < size(tRgBuf_n); ++n) { - output = reduce_output(output, tRgBuf_n(n)); - } - if (do_store) { - gCol_l(m,_0{},l) = convert_output(output); - } - } - } - } - - } - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - Layout ref_layout_MN = [&] () { - auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); - if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } - else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } - }(); // tile_mn -> tv_idx - - // Get the MN layout + coord of lanes to determine shuffle reduction iterations - using _W = Int; - Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx - Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx - Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx - Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn - int lane_idx = canonical_lane_idx(); - auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); - - // Get the MN layout + coord of warps to determine smem reduction iterations - Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx - Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx - Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx - Layout inv_warp_layout_MN = right_inverse(warp_layout_MN); // warp_idx -> warp_mn - int warp_idx = args.thread_idx / NumThreadsPerWarp; - auto warp_mn = idx2crd(inv_warp_layout_MN(warp_idx), shape(warp_layout_MN)); - - // Partition output gmem and register tensors - auto [tile_M, tile_N, tile_K] = args.tile_shape_mnk; - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); // (M,N,L) - Tensor gCol_l = local_tile(mCol, take<0,2>(args.tile_shape_mnk), make_coord(m,n,_)); // (CTA_M,CTA_N,L) - Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - gCol_l(_,_,l), args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - fill(tCrCol, params.reduction_identity); - - // Partition gmem+smem reduction buffer tensors - Layout gBuf_layout = make_layout(take<0,2>(args.tile_shape_mnk), make_stride(_1{}, _0{})); - Layout mBuf_layout = blocked_product(gBuf_layout, make_layout(ceil_div(make_shape(M,N,L), shape(gBuf_layout)))); - Tensor mBuf = make_tensor(make_gmem_ptr(params.reduction_buffer), mBuf_layout); // (ceil_M,ceil_N,L) - Tensor gBuf_nl = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(m,_,_)); // (CTA_M,CTA_N,REST_N,L) - Layout sBuf_layout = blocked_product(gBuf_layout,make_layout(make_shape(_1{},_1{},size<1>(warp_layout_MN)))); // (CTA_M,CTA_N,WARPS_N) - - auto args_tuple = make_tuple( - bool_constant{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - args.tile_coord_mnkl, args.residue_cD, args.residue_tCcD, args.epi_tile, args.tiled_copy, args.thread_idx); - return ConsumerStoreCallbacks(std::move(args_tuple), params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Batch matrix reduction -template < - int Stages, - class EpilogueTile, - class Element, - class StrideMNL, - class CopyOpR2S, - class SmemLayoutAtom, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true // Noop on nullptr params -> -struct Sm90MatrixReduction; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp deleted file mode 100644 index 93720f8d3d71f3f4759463b5d40e604313b7e3a4..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ /dev/null @@ -1,1149 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree operation base implementation to enable composable fusions - for the sm90 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/workspace.h" -#include "cutlass/detail/helper_macros.hpp" - -#include "cute/tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using cute::tuple; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partitioning Helpers -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class CtaTileMN, - class EpilogueTile, - class TiledCopy -> -CUTLASS_HOST_DEVICE -constexpr auto -sm90_partition_for_epilogue( - CtaTileMN cT, // (CTA_M,CTA_N,...) - EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) - TiledCopy tiled_copy, - int thread_idx) { - ThrCopy thread_copy = tiled_copy.get_thread_slice(thread_idx); - Tensor cT_epi = flat_divide(cT, epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,...) - if constexpr (ReferenceSrc) { - return thread_copy.partition_S(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) - } - else { - return thread_copy.partition_D(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) - } -} - -template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class Engine, class LayoutMNL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy -> -CUTLASS_HOST_DEVICE -constexpr auto -sm90_partition_for_epilogue( - Tensor mT, // (M,N,L) - TileShapeMNK tile_shape_mnk, // (CTA_M,CTA_N,CTA_K) - TileCoordMNKL tile_coord_mnkl, // (m,n,k,l) - EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) - TiledCopy tiled_copy, - int thread_idx) { - auto [m, n, k, l] = tile_coord_mnkl; - auto coord_shape = - make_coord(m, n, l) - ; - Tensor cT = local_tile(mT, take<0,2>(tile_shape_mnk), coord_shape); // (CTA_M,CTA_N) - Tensor tCcT = - sm90_partition_for_epilogue(cT, epi_tile, tiled_copy, thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - return tCcT; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Visitor Implementation -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Producer load callbacks, called by the epilogue load warp. -// Operations usually only define this if TMA load is needed. Most operations will reuse this empy implementation -// Load callbacks are responsible for issuing corresponding mbarrier expect-tx ops for any TMA loads issued, but -// are not responsible for issuing the producer_commit barrier arrival, which is issued by the collective instead -// If this is non-empty, is_producer_load_needed must be true. -// -template -struct ProducerLoadCallbacksImpl { - // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables - CallbacksTuple callbacks_tuple; - - // Before entry of the subtile load loop - CUTLASS_DEVICE void - begin() { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin(); - } - ); - } - - // Entry of the subtile load loop. Aux loads usually performed here - // Upon entry the producer acquire of the current subtile lock has completed. - // Upon exit all TMA loads for this subtile must have been issued, with corresponding expect-tx operations - CUTLASS_DEVICE void - step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.step(full_mbarrier_ptr, epi_m, epi_n, load_iteration, issue_tma_load); - } - ); - } - - // Exit of the subtile load loop. - CUTLASS_DEVICE void - end() { - for_each(callbacks_tuple, - [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end(); - } - ); - } -}; - - -// -// Consumer store callbacks, called by the epilogue store warps. -// All operations must redefine this, with optional inheritance from this empty implementation. -// -template -struct ConsumerStoreCallbacksImpl { - // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables - CallbacksTuple callbacks_tuple; - - // Before entry of subtile store loop. Gmem broadcasts usually performed here. - CUTLASS_DEVICE void - begin() { - for_each(callbacks_tuple, - [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin(); - } - ); - } - - // Is a thread sync needed after begin(). Allows chaining async copies across multiple nodes - CUTLASS_DEVICE bool - begin_sync_needed() const { - return cute::apply(callbacks_tuple, - [] (auto const&... callbacks) { - return (false || ... || callbacks.begin_sync_needed()); - } - ); - } - - // Start of subtile store iteration - CUTLASS_DEVICE void - begin_loop(int epi_m, int epi_n) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin_loop(epi_m, epi_n); - } - ); - } - - // Before visit callback. Smem broadcasts usually performed here. - // Upon entry, all producer loads for this subtile are completed and visible. - CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.previsit(epi_m, epi_n, load_iteration, is_producer_load_needed); - } - ); - } - - // Perform the fused elementwise computation - template - CUTLASS_DEVICE auto // returns an Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const&... frg_inputs) // depends on the N-naryness of the op - = delete; // Must be implemented for each operation - - // After visit call. Smem reductions usually performed here - // reduction_buffer is an arbitrary smem tensor that can be used for workspace - // It is each nodes reponsibility to assert that this buffer is sufficiently sized - // and to ensure that this buffer is no longer needed upon callback exit - // i.e. results are synchronized and no longer in the reduction buffer - // - // visit_results is a rmem tensor that contains the results of visit() for an entire - // on the current epilogue subtile - template - CUTLASS_DEVICE void - reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration, visit_results); - } - ); - } - - // After reduce call, before smem async fence. Smem stores usually performed here. - // Upon exit, all smem stores for TMA must have been issued - CUTLASS_DEVICE void - postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.postreduce(epi_m, epi_n, store_iteration, issue_smem_store); - } - ); - } - - // After smem async fence, before TMA store commit. Aux stores usually performed here - // Upon exit, all TMA stores for this subtile must have been issued - // Because of the TMA store delay optimization, this entry point must ONLY be used for TMA stores - // other gmem stores can be placed in the reduce or postreduce entry points - CUTLASS_DEVICE void - tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.tma_store(epi_m, epi_n, store_iteration, issue_tma_store); - } - ); - } - - // End of subtile store iteration - CUTLASS_DEVICE void - end_loop(int epi_m, int epi_n) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end_loop(epi_m, epi_n); - } - ); - } - - // Exit of subtile store loop. Gmem reductions usually performed here. - CUTLASS_DEVICE void - end() { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end(); - } - ); - } -}; - -template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class TiledMma, - class EpilogueTile -> -struct ProducerLoadArgs { - ProblemShapeMNKL problem_shape_mnkl; - TileShapeMNK tile_shape_mnk; - TileCoordMNKL tile_coord_mnkl; - TiledMma tiled_mma; - EpilogueTile epi_tile; - int thread_idx; - - CUTLASS_DEVICE - ProducerLoadArgs( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - TiledMma tiled_mma, - EpilogueTile epi_tile, - int thread_idx) - : problem_shape_mnkl(problem_shape_mnkl), - tile_shape_mnk(tile_shape_mnk), - tile_coord_mnkl(tile_coord_mnkl), - tiled_mma(tiled_mma), - epi_tile(epi_tile), - thread_idx(thread_idx) {} -}; - -template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class TiledMma, - class EpilogueTile, - class TiledCopy, - class CoordTensor, - class Residue, - class ThrCoordTensor, - class ThrResidue, - class ThrSrcTensor -> -struct ConsumerStoreArgs { - ProblemShapeMNKL problem_shape_mnkl; - TileShapeMNK tile_shape_mnk; - TileCoordMNKL tile_coord_mnkl; - TiledMma tiled_mma; - EpilogueTile epi_tile; - TiledCopy tiled_copy; - CoordTensor cD; - Residue residue_cD; - ThrCoordTensor tCcD; - ThrResidue residue_tCcD; - ThrSrcTensor & tCrC; - int thread_idx; - - CUTLASS_DEVICE - ConsumerStoreArgs( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - TiledMma tiled_mma, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - CoordTensor cD, - Residue residue_cD, - ThrCoordTensor tCcD, - ThrResidue residue_tCcD, - ThrSrcTensor & tCrC, - int thread_idx) - : problem_shape_mnkl(problem_shape_mnkl), - tile_shape_mnk(tile_shape_mnk), - tile_coord_mnkl(tile_coord_mnkl), - tiled_mma(tiled_mma), - epi_tile(epi_tile), - tiled_copy(tiled_copy), - cD(cD), - residue_cD(residue_cD), - tCcD(tCcD), - residue_tCcD(residue_tCcD), - tCrC(tCrC), - thread_idx(thread_idx) {} -}; - -template -struct Sm90VisitorImplBase { - // Shared memory allocation - using SharedStorage = tuple; - // Host side fusion arguments - using Arguments = tuple; - // Device side fusion params (Kernel-entry API) - using Params = tuple; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - uint8_t* op_workspace = reinterpret_cast(workspace); - return transform_apply(tuple{}, args, - [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { - using Op = cute::remove_cvref_t; - auto ret = Op::to_underlying_arguments(problem_shape, op_args, op_workspace); - if (op_workspace != nullptr) { - size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); - op_workspace += round_nearest(op_workspace_size, MinWorkspaceAlignment); - } - return ret; - }, - [] (auto&&... op_params) CUTLASS_LAMBDA_FUNC_INLINE { return cute::make_tuple(op_params...); } - ); - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return transform_apply(tuple{}, args, - [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { - using Op = cute::remove_cvref_t; - return Op::can_implement(problem_shape, op_args); - }, - [&] (auto&&... implementable) CUTLASS_LAMBDA_FUNC_INLINE { - return (true && ... && implementable); - } - ); - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return transform_apply(tuple{}, args, - [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { - using Op = cute::remove_cvref_t; - size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); - return round_nearest(op_workspace_size, MinWorkspaceAlignment); - }, - [&] (auto&&... op_workspace_size) CUTLASS_LAMBDA_FUNC_INLINE { - return (0 + ... + op_workspace_size); - } - ); - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* op_workspace = reinterpret_cast(workspace); - return transform_apply(tuple{}, args, - // Initialize each operation's workspace, stopping at the first error - [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { - if (status != Status::kSuccess) { - return status; - } - - using Op = cute::remove_cvref_t; - status = Op::initialize_workspace(problem_shape, op_args, op_workspace, stream, cuda_adapter); - if (op_workspace != nullptr) { - size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); - op_workspace += round_nearest(op_workspace_size, MinWorkspaceAlignment); - } - return status; - }, - // Return the final status - [&] (auto const&...ops) CUTLASS_LAMBDA_FUNC_INLINE { return status; } - ); - } - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) - : ops(transform_apply(tuple{}, params, shared_storage, - [] (auto&& op, auto const& op_params, auto&& op_storage) CUTLASS_LAMBDA_FUNC_INLINE { - using Op = cute::remove_cvref_t; - return Op(op_params, op_storage); - }, - [] (auto&&... ops) CUTLASS_LAMBDA_FUNC_INLINE { return cute::make_tuple(ops...); } - )) {} - - // Ops can store kernel persistent variables (e.g. descriptors, scalars, wave counters) - tuple ops; -}; - -template -struct Sm90VisitorImpl : Sm90VisitorImplBase { - - using Impl = Sm90VisitorImplBase; - using Params = typename Impl::Params; - using SharedStorage = typename Impl::SharedStorage; - - CUTLASS_HOST_DEVICE - Sm90VisitorImpl() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImpl(Params const& params, SharedStorage const& shared_storage) - : Impl(params, shared_storage) {} - - using Impl::ops; - - // - // Queries for kernel runtime - // - - // Is a specialized warp for producer TMA loads needed - // e.g. Aux tensor loads, broadcasts using TMA bulk copy - // This condition cannot change between work tiles because it is used - // to determine whether the load warp should exit early or not - // e.g. for batched beta this must always be true regardless of current batch idx - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return cute::apply(ops, - [] (auto const&... op) CUTLASS_LAMBDA_FUNC_INLINE { - return (false || ... || op.is_producer_load_needed()); - } - ); - } - - // Is a producer TMA load specifically for C needed - // If this is true then is_producer_load_needed must also be true - // This condition can change between work tiles because it is only used - // to determine whether the TMA and smem loads for C of a given tile should happen - // e.g. for batched beta this can be false depending on current batch idx - CUTLASS_DEVICE bool - is_C_load_needed() const { - return cute::apply(ops, - [] (auto const&... op) CUTLASS_LAMBDA_FUNC_INLINE { - return (false || ... || op.is_C_load_needed()); - } - ); - } - - // Producer load callbacks factory - // All operations must redefine this, but most can just dispatch to the base impl - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return transform_apply(ops, - [&] (auto& op) CUTLASS_LAMBDA_FUNC_INLINE { - return op.get_producer_load_callbacks(args); - }, - [] (auto&&... callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - auto callbacks_tuple = cute::make_tuple(callbacks...); - return ProducerLoadCallbacksImpl{callbacks_tuple}; - } - ); - } - - // Consumer store callbacks factory - // All operations must redefine this - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return transform_apply(ops, - [&] (auto& op) CUTLASS_LAMBDA_FUNC_INLINE { - return op.template get_consumer_store_callbacks(args); - }, - [] (auto&&... callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - auto callbacks_tuple = cute::make_tuple(callbacks...); - return ConsumerStoreCallbacksImpl{callbacks_tuple}; - } - ); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Convenience aliases -using EmptyProducerLoadCallbacks = ProducerLoadCallbacksImpl>; -using EmptyConsumerStoreCallbacks = ConsumerStoreCallbacksImpl>; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Tree visitor -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Sm90TreeVisitor : Sm90VisitorImpl { - - using Impl = Sm90VisitorImpl; - using Params = typename Impl::Params; - using SharedStorage = typename Impl::SharedStorage; - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor() {} - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor( - Params const& params, - SharedStorage const& shared_storage) - : Impl(params, shared_storage) {} - - template - struct ConsumerStoreCallbacks : CallbacksImpl { - CUTLASS_DEVICE - ConsumerStoreCallbacks(CallbacksImpl&& impl) - : CallbacksImpl(cute::forward(impl)) {} - - using CallbacksImpl::callbacks_tuple; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - constexpr int Rm1 = sizeof...(ChildOps); - return cute::detail::tapply(callbacks_tuple, - [&] (auto& child_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - return child_callbacks.visit(frg_acc, epi_v, epi_m, epi_n); // child ops must be nullary (e.g. loads, trees) - }, - [&] (auto&&... frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { - return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); - }, - make_seq{} // restrict the transform to R-1 child ops, apply is for node op - ); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_impl = Sm90VisitorImpl:: - template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(cute::move(callbacks_impl)); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// DAG visitors -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Most DAG fusions can be represented as a set of output trees with a common input tree -// The common input is first evaluated, then the result is passed as the acc fragment to the output trees -template -struct Sm90SplitTreeVisitor : Sm90VisitorImpl { - - using Sm90VisitorImpl::Sm90VisitorImpl; - - template - struct ConsumerStoreCallbacks : CallbacksImpl { - CUTLASS_DEVICE - ConsumerStoreCallbacks(CallbacksImpl&& impl) - : CallbacksImpl(cute::forward(impl)) {} - - using CallbacksImpl::callbacks_tuple; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_input = get<0>(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); - - constexpr int Rm2 = sizeof...(AuxOutTrees); - cute::for_each(make_seq{}, // restrict the sequence to aux out trees - [&] (auto I) CUTLASS_LAMBDA_FUNC_INLINE { - get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); - } - ); - - return get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_impl = Sm90VisitorImpl:: - template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(cute::move(callbacks_impl)); - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // deducing the output type for all the nodes is tricky so we just convert them all to a common type - // if multiple compute types are needed then split into multiple subgraphs grouped by type - class ElementCompute, - class EdgeTuple, // tuple of int_sequence, each sequence is the children indices (indexed by topological order) for each node - class... Ops // in topological order, last op is the output. EdgeTuple must match this order -> -struct Sm90TopologicalVisitor : Sm90VisitorImpl { - static_assert(is_static_v); - static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops)); - static_assert(sizeof...(Ops) > 1); - - using Sm90VisitorImpl::Sm90VisitorImpl; - - template - struct ConsumerStoreCallbacks : CallbacksImpl { - CUTLASS_DEVICE - ConsumerStoreCallbacks(CallbacksImpl&& impl) - : CallbacksImpl(cute::forward(impl)) {} - - using CallbacksImpl::callbacks_tuple; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - constexpr int Rm1 = sizeof...(Ops) - 1; - auto frg_compute_tuple = cute::repeat(Array{}); - - return cute::detail::tapply(EdgeTuple{}, callbacks_tuple, frg_compute_tuple, - // Visit the first R-1 ops in topological order - [&] (auto&& edge_seq, auto& callbacks, auto& frg_compute) CUTLASS_LAMBDA_FUNC_INLINE { - frg_compute = cute::detail::apply(frg_compute_tuple, - // Compute the current op with children inputs - [&] (auto const&... frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { - auto frg_output = callbacks.visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); - using ElementOutput = typename decltype(frg_output)::Element; - using ConvertOutput = NumericArrayConverter; - ConvertOutput convert_output{}; - - return convert_output(frg_output); - }, - // Get inputs in the sequence given by the children indices of the current op - edge_seq - ); - return frg_compute; // unused - }, - // Visit the last op - [&] (auto const&...ops) CUTLASS_LAMBDA_FUNC_INLINE { - return cute::detail::apply(frg_compute_tuple, - // Compute the last op with children inputs - [&] (auto const&... frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { - return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); - }, - // Get inputs in the sequence given by the children indices of the last op - get(EdgeTuple{}) - ); - }, - // Transform to visit R-1 ops, apply to visit last op - make_seq{} - ); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_impl = Sm90VisitorImpl:: - template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(cute::move(callbacks_impl)); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Base specializations so we can have standard layout params and simple aggregate initializers -namespace detail { - -template -struct Sm90VisitorImplBase { - - // Retain tuple for SharedStorage because empty structs have 1B alignment - // tuples use multiple inheritance, avoids this problem - using SharedStorage = tuple< - typename Op0::SharedStorage - >; - - struct Arguments { - typename Op0::Arguments op_0; - }; - - struct Params { - typename Op0::Params op_0; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return Params{ - Op0::to_underlying_arguments(problem_shape, args.op_0, workspace) - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return Op0::can_implement(problem_shape, args.op_0); - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - size_t workspace_size = 0; - workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) - : ops({ - Op0(params.op_0, get<0>(shared_storage)) - }) {} - - tuple ops; -}; - -template -struct Sm90VisitorImplBase { - - using SharedStorage = tuple< - typename Op0::SharedStorage, - typename Op1::SharedStorage - >; - - struct Arguments { - typename Op0::Arguments op_0; - typename Op1::Arguments op_1; - }; - - struct Params { - typename Op0::Params op_0; - typename Op1::Params op_1; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); - uint8_t* op_0_workspace = reinterpret_cast(workspace); - uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; - return Params{ - Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), - Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace) - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return Op0::can_implement(problem_shape, args.op_0) && - Op1::can_implement(problem_shape, args.op_1); - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - size_t workspace_size = 0; - workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) - : ops({ - Op0(params.op_0, get<0>(shared_storage)), - Op1(params.op_1, get<1>(shared_storage)) - }) {} - - tuple ops; -}; - -template -struct Sm90VisitorImplBase { - - using SharedStorage = tuple< - typename Op0::SharedStorage, - typename Op1::SharedStorage, - typename Op2::SharedStorage - >; - - struct Arguments { - typename Op0::Arguments op_0; - typename Op1::Arguments op_1; - typename Op2::Arguments op_2; - }; - - struct Params { - typename Op0::Params op_0; - typename Op1::Params op_1; - typename Op2::Params op_2; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); - size_t op_1_workspace_size = Op1::get_workspace_size(problem_shape, args.op_1); - uint8_t* op_0_workspace = reinterpret_cast(workspace); - uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; - uint8_t* op_2_workspace = op_1_workspace + op_1_workspace_size; - return Params{ - Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), - Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace), - Op2::to_underlying_arguments(problem_shape, args.op_2, op_2_workspace) - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return Op0::can_implement(problem_shape, args.op_0) && - Op1::can_implement(problem_shape, args.op_1) && - Op2::can_implement(problem_shape, args.op_2); - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - size_t workspace_size = 0; - workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) - : ops({ - Op0(params.op_0, get<0>(shared_storage)), - Op1(params.op_1, get<1>(shared_storage)), - Op2(params.op_2, get<2>(shared_storage)) - }) {} - - tuple ops; -}; - -template -struct Sm90VisitorImplBase { - - using SharedStorage = tuple< - typename Op0::SharedStorage, - typename Op1::SharedStorage, - typename Op2::SharedStorage, - typename Op3::SharedStorage - >; - - struct Arguments { - typename Op0::Arguments op_0; - typename Op1::Arguments op_1; - typename Op2::Arguments op_2; - typename Op3::Arguments op_3; - }; - - struct Params { - typename Op0::Params op_0; - typename Op1::Params op_1; - typename Op2::Params op_2; - typename Op3::Params op_3; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); - size_t op_1_workspace_size = Op1::get_workspace_size(problem_shape, args.op_1); - size_t op_2_workspace_size = Op2::get_workspace_size(problem_shape, args.op_2); - uint8_t* op_0_workspace = reinterpret_cast(workspace); - uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; - uint8_t* op_2_workspace = op_1_workspace + op_1_workspace_size; - uint8_t* op_3_workspace = op_2_workspace + op_2_workspace_size; - return Params{ - Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), - Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace), - Op2::to_underlying_arguments(problem_shape, args.op_2, op_2_workspace), - Op3::to_underlying_arguments(problem_shape, args.op_3, op_3_workspace) - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return Op0::can_implement(problem_shape, args.op_0) && - Op1::can_implement(problem_shape, args.op_1) && - Op2::can_implement(problem_shape, args.op_2) && - Op3::can_implement(problem_shape, args.op_3); - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - size_t workspace_size = 0; - workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op3::get_workspace_size(problem_shape, args.op_3); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op3::initialize_workspace(problem_shape, args.op_3, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op3::get_workspace_size(problem_shape, args.op_3); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) - : ops({ - Op0(params.op_0, get<0>(shared_storage)), - Op1(params.op_1, get<1>(shared_storage)), - Op2(params.op_2, get<2>(shared_storage)), - Op3(params.op_3, get<3>(shared_storage)) - }) {} - - tuple ops; -}; - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp deleted file mode 100644 index bd378419567b1680c400ec38746211a577a3c409..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp +++ /dev/null @@ -1,763 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree Top-K + Softmax fusion operation for sm90 TMA warp-specialized epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/workspace.h" - -#include "cute/tensor.hpp" -#include "sm90_visitor_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Top-K + Softmax reduction across columns -// Performs a reduction of top-K values across N, and finally performs a softmax on them, -// and sets values not in the top-K to 0. -// -// Assumptions: -// 1. CTA_N >= N (single tile across N, the mode which is reduced) -// 2. EPI_N >= N (single epilogue tile across N, because we can reduce and revisit one -// epilogue tile at a time.) -// 3. Top-K value is either 2 or 4. -// - -namespace detail { - -// Implementations for add to sorted list and merging sorted lists, -// with fast paths for lists of size 2 and 4 (Top-2 and Top-4). -// Generic implementations may result in greater register use and branching, -// and should be avoided. -// Fast paths for Top-2 and Top-4 are written in inline PTX directly. - -CUTLASS_DEVICE -Array top_2_reduce_scalar(Array a, float scalar) { - Array out; - asm volatile( - "{\n" - " .reg .f32 mx;\n" - " .reg .pred p;\n" - " max.f32 mx, %3, %4;\n" - " setp.gtu.f32 p, %2, %4;\n" - " selp.f32 %1, mx, %2, p;\n" - " selp.f32 %0, %2, %4, p;\n" - "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(scalar)); - return out; -} - -CUTLASS_DEVICE -Array top_2_reduce(Array a, Array b) { - Array out; - asm volatile( - "{\n" - " .reg .v2 .f32 mx;\n" - " .reg .pred p;\n" - " max.f32 mx.x, %3, %4;\n" // max(a1, b0) - " max.f32 mx.y, %2, %5;\n" // max(a0, b1) - " setp.gtu.f32 p, %2, %4;\n" // a0 > b0 - " selp.f32 %1, mx.x, mx.y, p;\n" // a0 > b0 ? max(a1, b0) : max(a0, b1) - " selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? a0 : b0 - "}\n" : "=f"(out[0]), "=f"(out[1]) : - "f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1])); - return out; -} - -CUTLASS_DEVICE -Array top_4_reduce_scalar(Array a, float scalar) { - Array out; - asm volatile( - "{\n" - " .reg .f32 mx;\n" // max(a3, b) - " .reg .pred p0;\n" // a0 > b - " .reg .pred p1;\n" // a1 > b - " .reg .pred p2;\n" // a2 > b - " max.f32 mx, %7, %8;\n" // max(a3, b) - " setp.gtu.f32 p0, %4, %8;\n" // a0 > b - " setp.gtu.f32 p1, %5, %8;\n" // a1 > b - " setp.gtu.f32 p2, %6, %8;\n" // a2 > b - " selp.f32 %3, mx, %6, p2;\n" // a2 > b ? max(a3, b) : a2 - " selp.f32 %2, %6, %8, p2;\n" // a1 = a2 > b ? a2 : b - " selp.f32 %2, %2, %5, p1;\n" // a1 > b ? max(a2, b) : a1 == a1 > b ? a1 : old_a1 - " selp.f32 %1, %5, %8, p1;\n" // a0 = a1 > b ? a1 : b - " selp.f32 %1, %1, %4, p0;\n" // a0 > b ? max(a1, b) : a0 == a0 > b ? a0 : old_a0 - " selp.f32 %0, %4, %8, p0;\n" // a0 = a0 > b ? a0 : b - "}\n" : - "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : - "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar)); - return out; -} - -CUTLASS_DEVICE -Array top_4_reduce(Array a, Array b) { - Array out; - asm volatile( - "{\n" - " .reg .f32 mxa0b1;\n" // max(a0, b1) - " .reg .f32 mxa1b0;\n" // max(a1, b0) - - " .reg .f32 mxa2b0;\n" // max(a2, b0) - " .reg .f32 mxa1b1;\n" // max(a1, b1) - " .reg .f32 mxa0b2;\n" // max(a1, b1) - - " .reg .f32 mxa1b2;\n" // max(a1, b2) - " .reg .f32 mxa2b1;\n" // max(a2, b1) - " max.f32 mxa1b2, %5, %10;\n" - " max.f32 mxa2b1, %6, %9;\n" - - " .reg .f32 mxa3b0;\n" // max(a1, b2) - " .reg .f32 mxa0b3;\n" // max(a2, b1) - " max.f32 mxa3b0, %7, %8;\n" - " max.f32 mxa0b3, %4, %11;\n" - - " .reg .pred pa0b0;\n" // a0 > b0 - " .reg .pred pa1b0;\n" // a1 > b0 - " .reg .pred pa2b0;\n" // a2 > b0 - " .reg .pred pa0b1;\n" // a0 > b1 - " .reg .pred pa1b1;\n" // a1 > b1 - " .reg .pred pa0b2;\n" // a0 > b2 - " .reg .pred pb2a0;\n" // b1 > a0 - " .reg .pred pb1a0;\n" // b1 > a0 - - " setp.gtu.f32 pa0b0, %4, %8;\n" // a0 > b0 - " setp.gtu.f32 pa1b0, %5, %8;\n" // a1 > b0 - " setp.gtu.f32 pa2b0, %6, %8;\n" // a2 > b0 - " setp.gtu.f32 pa0b1, %4, %9;\n" // a0 > b1 - " setp.gtu.f32 pa1b1, %5, %9;\n" // a1 > b1 - " setp.gtu.f32 pa0b2, %4, %10;\n" // a0 > b2 - - " not.pred pb2a0, pa0b2;\n" - " not.pred pb1a0, pa0b1;\n" - - " selp.f32 mxa1b0, %5, %8, pa1b0;\n" // max(a1, b0) - " selp.f32 mxa0b1, %4, %9, pa0b1;\n" // max(a0, b1) - - " selp.f32 mxa1b1, %5, %9, pa1b1;\n" // max(a1, b1) - " selp.f32 mxa2b0, %6, %8, pa2b0;\n" // max(a2, b0) - " selp.f32 mxa0b2, %4, %10, pa0b2;\n" // max(a0, b2) - - // a0 - " selp.f32 %0, %4, %8, pa0b0;\n" // a0 = a0 > b0 ? a0 : b0 - - // a1 - " selp.f32 %1, mxa1b0, mxa0b1, pa0b0;\n" // a1 = a0 > b0 ? max(a1, b0) : max(a0, b1) - - // a2 - " mov.f32 %2, mxa1b1;\n" // a2 = max(a1, b1) ** most likely case - " selp.f32 %2, mxa2b0, %2, pa1b0;\n" // a0 > a1 > b0 - " selp.f32 %2, mxa0b2, %2, pb1a0;\n" // b0 > b1 > a0 - - // a3 - " mov.f32 %3, mxa1b2;\n" // a3 = max(a1, b2) ** one of the most likely cases - " selp.f32 %3, mxa2b1, %3, pa1b1;\n" // a3 = a1 > b1 ? max(a2, b1) ** second most likely case - " selp.f32 %3, mxa3b0, %3, pa2b0;\n" // a0 > a1 > a2 > b0 - " selp.f32 %3, mxa0b3, %3, pb2a0;\n" // b0 > b1 > b2 > a0 - "}\n" : - "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : - "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), - "f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3])); - return out; -} - -// Assumption: array elements are sorted in descending order -// (a[0] is the largest element in a[].) -template -CUTLASS_DEVICE -void add_element_to_desc_sorted_array(cutlass::Array& a, Element b) { - if constexpr (N == 2 && is_same_v) { - a = top_2_reduce_scalar(a, b); - } - else if constexpr (N == 4 && is_same_v) { - a = top_4_reduce_scalar(a, b); - } - else { - // slower generic path with branching, slower, and can cause register spill - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < N; ++k) { - if (a[k] < b) { - // Shift down - CUTLASS_PRAGMA_UNROLL - for (int l = N - 1; l > k; --l) { - a[l] = a[l-1]; - } - a[k] = b; - break; - } - } - } -} - -// Assumption: array elements are sorted in descending order -// (a[0] and b[0] are the largest elements in a[] and b[].) -template -CUTLASS_DEVICE -void merge_desc_sorted_arrays(cutlass::Array& a, const cutlass::Array& b) { - if constexpr (N == 2 && is_same_v) { - a = top_2_reduce(a, b); - } - else if constexpr (N == 4 && is_same_v) { - a = top_4_reduce(a, b); - } - else { - // slower generic path with branching, slower, and can cause register spill - int j = 0; - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < N; ++k) { - if (a[k] < b[j]) { - // Shift down - CUTLASS_PRAGMA_UNROLL - for (int l = N - 1; l > k; --l) { - a[l] = a[l-1]; - } - a[k] = b[j]; - ++j; - } - } - } -} - -// Assumption: array elements are sorted in descending order -// (a[0] is the largest element in a[].) -template -CUTLASS_DEVICE -Element topk_logsumexp(cutlass::Array a) { - // Do one less `exp`, because we know what its result will be. - // Assume x is a set of `x_i`s, and `x_m` is the maximum of that set. - // logsumexp(x) = log(sum(x_i)) = m + log(sum(x_i - m)) = m + log(1 + sum_{i != m}(x_i - x_m)) - // Compute m + log(1 + sum_{i != m}(x_i - x_m)) - Element sum = Element(1.0); - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < N; ++i) { - sum += fast_exp(a[i] - a[0]); - } - return a[0] + fast_log(sum); -} - -CUTLASS_DEVICE -float fast_masked_softmax(float value, float minimum, float logsumexp) { - float new_value; - asm volatile( - "{\n" - " .reg .pred p0;\n" - // value >= minimum - " setp.geu.f32 p0, %1, %2;\n" - - " .reg .f32 x_lse;\n" - " .reg .f32 %%f<11>;\n" - " .reg .b32 %%r<3>;\n" - - // x_lse = value - minimum - " sub.rn.f32 x_lse, %1, %3;\n" - - // exp(x_lse) - // The following is derived from a ptx dump of expf. - // exp requires a base conversion from exp2. - " fma.rn.f32 %%f1, x_lse, 0f3BBB989D, 0f3F000000;\n" - " cvt.sat.f32.f32 %%f2, %%f1;\n" - " fma.rm.f32 %%f3, %%f2, 0f437C0000, 0f4B400001;\n" - " add.f32 %%f4, %%f3, 0fCB40007F;\n" - " neg.f32 %%f5, %%f4;\n" - " fma.rn.f32 %%f6, x_lse, 0f3FB8AA3B, %%f5;\n" - " fma.rn.f32 %%f7, x_lse, 0f32A57060, %%f6;\n" - " mov.b32 %%r1, %%f3;\n" - " shl.b32 %%r2, %%r1, 23;\n" - " mov.b32 %%f8, %%r2;\n" - " ex2.approx.ftz.f32 %%f9, %%f7;\n" - " mul.f32 %%f10, %%f9, %%f8;\n" - - // Mask or softmax - " selp.f32 %0, %%f10, 0f00000000, p0;\n" - "}\n" : "=f"(new_value) : "f"(value), "f"(minimum), "f"(logsumexp)); - return new_value; -} - -template -CUTLASS_DEVICE -Element masked_softmax(Element value, Element minimum, Element logsumexp) { - if constexpr (is_same_v) { - // Inline PTX implementation - // Significantly reduces register requirements - return fast_masked_softmax(value, minimum, logsumexp); - } - else { - return value < minimum ? Element(0.0) : fast_exp(value - logsumexp); - } -} - -} // namespace detail - -template < - int TopK, - int FragmentSize, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - int Alignment = 128 / sizeof_bits_v, - bool UseButterflyReduce = true -> -struct Sm90TopKSoftmaxColReduction { -private: - static_assert(is_same_v, "Fused Top-K + Softmax reduction requires FP32 accumulation."); - static_assert(TopK == 2 || TopK == 4, - "Fused Top-K + Softmax reduction only allows K=2 and K=4, because those cases have been performance-optimized. Other values of K can be enabled by removing this assertion, but they may come with serious performance implications." - ); - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - - // Reduction tensors - // We have two tensors for this EVT node: a reduction tensor and a tensor holding - // final reduction values (tCrSoftmax). The reason for this is that Top-K and Softmax - // require different reductions, but those luckily overlap. Top-K obviously needs at least - // two values (K >= 2), and softmax needs one value: logsumexp. Logsumexp is simply the log - // of sum of exponents over the set, and is equivalent to m + sum(exp(x_i - m)), where m is the - // maximum of all x_i elements. Since safe softmax for any element x_i is computed as - // softmax(x_i) = exp(x_i - m) / sum_j(exp(x_j - max)) - // we can track logsumexp instead of tracking two variables (sum of exps and the max). - // In addition, subtracting logsumexp from any element and taking its exp is equivalent to - // computing its softmax. - // - // The overlap between softmax and top-K is that we don't need to reduce logsumexp along the - // way at all, because any element not in the top-K is going to be masked out and set to 0. - // Therefore, we only reduce the top-K elements, and when done, compute their logsumexp and - // keep it, and the smallest element in the top-K for masking out non-top-K elements. - // - // This means that our final reduction result will always be 2 elements, regardless of the value - // of K: minimum of top-K, and logsumexp. - // - // For each reduction tensor, we define a new struct for readability. - - struct ReductionResult { - ElementCompute min_; - ElementCompute logsumexp_; - - CUTLASS_DEVICE - ReductionResult() { } - - CUTLASS_DEVICE - ReductionResult(ElementCompute min, ElementCompute logsumexp): - logsumexp_(logsumexp), min_(min) { } - - // Warp shuffle broadcast - CUTLASS_DEVICE - void shuffle_up_sync(uint32_t delta, int lane_id) { - static_assert(sizeof(ReductionResult) == sizeof(uint64_t)); - uint64_t r = reinterpret_cast(*this); - r = __shfl_up_sync(0xFFFFFFFF, r, delta); - *this = (lane_id - static_cast(delta) >= 0) ? reinterpret_cast(r) : *this; - } - }; - - struct TopKResult { - Array top_k_; - - CUTLASS_DEVICE - TopKResult() { - top_k_.fill(-cutlass::platform::numeric_limits::infinity()); - } - - // This is where we do the "final" reduction, where we compute - // the logsumexp for softmax, keep the smallest value in top-K, - // and discard the rest. - CUTLASS_DEVICE - ReductionResult reduce_final() const { - return ReductionResult(top_k_[TopK - 1], topk_logsumexp(top_k_)); - } - - // Butterfly reduction - CUTLASS_DEVICE - void shuffle_xor_sync(int laneMask) { - if constexpr (TopK == 2) { - static_assert(sizeof(TopKResult) == sizeof(uint64_t)); - uint64_t top_k = reinterpret_cast(*this); - top_k = __shfl_xor_sync(0xFFFFFFFF, top_k, laneMask); - auto synced_v = reinterpret_cast(top_k); - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - else if constexpr (TopK == 4) { - static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); - uint64_t* top_k_ptr = reinterpret_cast(this); - uint64_t top_k_arr[2]; - top_k_arr[0] = top_k_ptr[0]; - top_k_arr[1] = top_k_ptr[1]; - top_k_arr[0] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[0], laneMask); - top_k_arr[1] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[1], laneMask); - auto synced_v = reinterpret_cast(top_k_arr); - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - else { - TopKResult synced_v; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < TopK; ++i) { - synced_v.top_k_[i] = __shfl_xor_sync(0xFFFFFFFF, top_k_[i], laneMask); - } - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - } - - // Warp shuffle reduction - CUTLASS_DEVICE - void shuffle_down_sync(uint32_t delta) { - if constexpr (TopK == 2) { - static_assert(sizeof(TopKResult) == sizeof(uint64_t)); - uint64_t top_k = reinterpret_cast(*this); - top_k = __shfl_down_sync(0xFFFFFFFF, top_k, delta); - auto synced_v = reinterpret_cast(top_k); - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - else if constexpr (TopK == 4) { - static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); - uint64_t* top_k_ptr = reinterpret_cast(this); - uint64_t top_k_arr[2]; - top_k_arr[0] = top_k_ptr[0]; - top_k_arr[1] = top_k_ptr[1]; - top_k_arr[0] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[0], delta); - top_k_arr[1] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[1], delta); - auto synced_v = reinterpret_cast(top_k_arr); - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - else { - TopKResult synced_v; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < TopK; ++i) { - synced_v.top_k_[i] = __shfl_down_sync(0xFFFFFFFF, top_k_[i], delta); - } - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - } - }; - -public: - struct SharedStorage { }; - - struct Arguments { }; - - struct Params { }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return {}; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - auto [M, N, K, L] = problem_shape; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - // Cross CTA reduction is not possible because there is no guarantee that all CTAs run - // concurrently. - // Cross epilogue tile reduction is possible, but re-visiting and applying reduction - // to accumulators is only possible for the current epilogue tile. - auto [epi_M, epi_N] = EpilogueTile{}; - return N <= tile_N && N <= epi_N && N >= TopK; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_HOST_DEVICE - Sm90TopKSoftmaxColReduction() { } - - CUTLASS_HOST_DEVICE - Sm90TopKSoftmaxColReduction(Params const& params, SharedStorage const& shared_storage) - : params(params) { } - - Params params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) - : args_tuple(cute::forward(args_tuple)), - params(params) {} - - ArgsTuple args_tuple; - Params const& params; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { - - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, - lane_layout_MN, lane_mn, - residue_cCol, residue_tCcCol] = args_tuple; - Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); - - using ConvertInput = NumericArrayConverter; - ConvertInput convert_input{}; - - Array frg_I = convert_input(frg_input); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - auto thread_crd = tCcCol_mn(epi_v * FragmentSize + i); - if (elem_less(thread_crd, residue_tCcCol)) { - TopKResult& tCrCol_vmn = tCrTopK(epi_v * FragmentSize + i); - detail::add_element_to_desc_sorted_array(tCrCol_vmn.top_k_, frg_I[i]); - } - } - - return frg_input; - } - - template - CUTLASS_DEVICE void - reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, - lane_layout_MN, lane_mn, - residue_cCol, residue_tCcCol] = args_tuple; - - // fully OOB CTA in partially OOB cluster - if (not elem_less(cCol(_0{},_0{}), residue_cCol)) { - return; - } - Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); - - // `tCrTopK` and `tCrSoftmax` have 0-strides along modes that correspond to N, - // in order to reduce along modes in the `R2S` sublayout that correspond to N. - // This means we should modify and warp-reduce them according to their co-domain instead of - // their domain. Therefore we keep a filtered view of both and use them as necessary. - auto tCrTopK_f = filter(tCrTopK); - auto tCrSoftmax_f = filter(tCrSoftmax); - - // The pattern here is: reduce Top-K first, then compute logsumexp, keep it and the - // last element of Top-K, use the latter to mask the visited results, and the former - // to apply softmax. - // - // This gives us two options: reduce the Top-K with warp shuffles, have the reduced - // lanes compute logsumexp and pair it with the last Top-K element, and broadcast - // the result back using warp shuffles. - // - // Alternatively, we can do a butterfly reduction over Top-K, and have all lanes - // compute their own logsumexp and skip the broadcast. - if constexpr (UseButterflyReduce) { - // - // 1. Butterfly reduction - // - CUTLASS_PRAGMA_UNROLL - for (int j = 1; j < size<1>(lane_layout_MN); j *= 2) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrTopK_f); ++i) { - tCrTopK_f(i).shuffle_xor_sync(j); - } - } - - // - // 2. Strip down reduced value and compute sum of exps - // - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrSoftmax_f); ++i) { - tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); - } - } - else { - // - // 1. Warp shuffle reduction - // - CUTLASS_PRAGMA_UNROLL - for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrTopK_f); ++i) { - tCrTopK_f(i).shuffle_down_sync(lane_layout_MN(_0{},reduction_cols)); - } - } - - // - // 2. Strip down reduced value and compute sum of exps - // - bool is_reduced_lane = get<1>(lane_mn) == 0; - if (is_reduced_lane) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrSoftmax_f); ++i) { - tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); - } - } - - // - // 3. Broadcast reduced values to all participants - // - CUTLASS_PRAGMA_UNROLL - for (int broadcast_cols = 1; broadcast_cols <= size<1>(lane_layout_MN) / 2; broadcast_cols *= 2) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrSoftmax_f); ++i) { - tCrSoftmax_f(i).shuffle_up_sync(lane_layout_MN(_0{},broadcast_cols), get<1>(lane_mn)); - } - } - } - - // - // 4. Re-visit and apply top-K and softmax - // - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(visit_results); ++epi_v) { - auto& visit_frag = visit_results(epi_v); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - visit_frag[i] = detail::masked_softmax( - visit_frag[i], - tCrSoftmax(epi_v * FragmentSize + i).min_, - tCrSoftmax(epi_v * FragmentSize + i).logsumexp_ - ); - } - } - - } - - CUTLASS_DEVICE void - end_loop(int epi_m, int epi_n) { - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, - lane_layout_MN, lane_mn, - residue_cCol, residue_tCcCol] = args_tuple; - - // Reset reduced top-K values for next tile - // This must be done because we only assume a single epilogue tile across N, - // but not M. - fill(tCrTopK, TopKResult()); - } - - CUTLASS_DEVICE void - end() { } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - Layout ref_layout_MN = [&] () { - auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); - if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } - else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } - }(); // tile_mn -> tv_idx - - // Get the MN layout + coord of lanes to determine shuffle reduction iterations - using _W = Int; - Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx - Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx - Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx - Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn - int lane_idx = canonical_lane_idx(); - auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); - - // Get the MN layout + coord of warps to determine smem reduction iterations - Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx - Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx - Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx - - // Make sure there's only one warp across N so we can use warp shuffle intrinsics for reduction. - static_assert(decltype(size<1>(warp_layout_MN))::value <= 1); - - // Reduction layout - // We're assuming all elements in a row (over which we're performing the reduction) are - // visited in the same corresponding epilogue tile, and this is what allows us to apply the - // top-K + softmax operation within `reduce()`, by re-visiting the accumulated results. - // - // This presents a challenge, because the layout of the accumulated results is typically in - // in the register to shared memory shape, or: (R2S,R2S_M,R2S_N). - // This means that we still need to reduce this tensor along N. - // - // The solution is simple: we need to flatten the layout, identify modes that correspond to - // N and set their strides to 0, in order to map fragment indices corresponding to the same - // row back to the same element in the tensor. - // - // This requires some extra layout manipulation, which is as follows. - - // Create new accumulator layout with column broadcast - auto [M, N, K] = args.tile_shape_mnk; - auto thr_mma = args.tiled_mma.get_thread_slice(args.thread_idx); - auto gColReduce = make_tensor( - make_layout(make_shape(M, N), make_stride(_1{}, 0_c))); // (M,N) - auto tCrColReduce = make_tensor_like( // (FrgV, MMA_M, MMA_N) - thr_mma.partition_C(gColReduce).layout()); - - // Tile the new accumulator tensor according to R2S - ThrCopy thread_r2s = args.tiled_copy.get_slice(args.thread_idx); - Tensor tRS_rSoftmax = thread_r2s.retile_S(tCrColReduce); // ((R2S,R2S_V),MMA_M,MMA_N) - auto tCrC_layout = args.tCrC.layout(); // (R2S,R2S_M,R2S_N) - - // Compose the new accumulator R2S layout with the expected tCrC layout to get final - // reduction tensor layout. - auto tCrSoftmax_layout = take<0, 3>(tRS_rSoftmax.layout()).compose(tCrC_layout); // (R2S,R2S_V) o (R2S,R2S_M,R2S_N) - - Tensor tCrTopK = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) - Tensor tCrSoftmax = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) - fill(tCrTopK, TopKResult()); - - auto args_tuple = make_tuple( - cute::move(tCrTopK), cute::move(tCrSoftmax), args.tCcD, args.cD, - lane_layout_MN, lane_mn, - args.residue_cD, args.residue_tCcD); - return ConsumerStoreCallbacks(std::move(args_tuple), params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/activation.h b/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/activation.h deleted file mode 100644 index 8412b5037b3aacbca4d28b80b99839acb368d5df..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/activation.h +++ /dev/null @@ -1,914 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief This extends the contents of cutlass/functional.h with frequently used activation functions. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/constants.h" -#include "cutlass/complex.h" -#include "cutlass/array.h" -#include "cutlass/half.h" -#include "cutlass/functional.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// If kIsHeavy is a member, use it. Otherwise, assume that it's false. -template -struct kIsHeavy_member_or_false { - static constexpr bool value = false; -}; -template -struct kIsHeavy_member_or_false::type> { - static constexpr bool value = Op::kIsHeavy; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Identity operator -template -struct Identity { - static const bool kIsHeavy = false; - - CUTLASS_HOST_DEVICE - T operator()(T value) const { - return value; - } -}; - -template -struct Identity > { - CUTLASS_HOST_DEVICE - Array operator()(Array value) const { - return value; - } -}; - -/// Scale operator -template -struct Scale { - struct Arguments { - using scale_type = T; - T scale = T(1); - }; - - CUTLASS_HOST_DEVICE - T operator()(T value, T scale) const { - multiplies mul; - return mul(scale, value); - } - - CUTLASS_HOST_DEVICE - T operator()(T value, Arguments args = Arguments()) const { - return this->operator()(value, args.scale); - } -}; - -template -struct Scale> { - using Arguments = typename Scale::Arguments; - - CUTLASS_HOST_DEVICE - Array operator()(Array values, T scale) const { - multiplies> mul; - return mul(scale, values); - } - - CUTLASS_HOST_DEVICE - Array operator()(Array values, Arguments args = Arguments()) const { - return this->operator()(values, args.scale); - } -}; - -/// Specialization to compose other activations with a defined unary operator -/// e.g. Scale> -template