Spaces:
Sleeping
Sleeping
| /*************************************************************************************************** | |
| * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| * SPDX-License-Identifier: BSD-3-Clause | |
| * | |
| * Redistribution and use in source and binary forms, with or without | |
| * modification, are permitted provided that the following conditions are met: | |
| * | |
| * 1. Redistributions of source code must retain the above copyright notice, this | |
| * list of conditions and the following disclaimer. | |
| * | |
| * 2. Redistributions in binary form must reproduce the above copyright notice, | |
| * this list of conditions and the following disclaimer in the documentation | |
| * and/or other materials provided with the distribution. | |
| * | |
| * 3. Neither the name of the copyright holder nor the names of its | |
| * contributors may be used to endorse or promote products derived from | |
| * this software without specific prior written permission. | |
| * | |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
| * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
| * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
| * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
| * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
| * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
| * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| * | |
| **************************************************************************************************/ | |
| /* \file | |
| \brief Helper functions for mapping CUTLASS concepts to cuBLAS. | |
| */ | |
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |
| namespace cutlass { | |
| namespace profiler { | |
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |
| /// Converts a cuBLAS status to cutlass::Status | |
| Status get_cutlass_status(cublasStatus_t cublas); | |
| /// Converts a cuBLASS status to cutlass::profiler::Disposition | |
| Disposition get_cutlass_disposition(cublasStatus_t cublas_status); | |
| /// Maps a CUTLASS tensor layout to a cuBLAS transpose operation | |
| bool get_cublas_transpose_operation( | |
| cublasOperation_t &operation, | |
| library::LayoutTypeID layout, | |
| library::ComplexTransform transform = library::ComplexTransform::kNone); | |
| /// Maps a CUTLASS numeric type to a cuBLAS data type enumeration | |
| bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type); | |
| /// Gets the cublas algorithm given threadblock tile dimensions and math opcode class | |
| cublasGemmAlgo_t get_cublas_gemm_algo( | |
| int cta_m, | |
| int cta_n, | |
| int cta_k, | |
| library::OpcodeClassID opcode_class); | |
| /// Returns a status if cuBLAS can satisfy a particular GEMM description | |
| Status cublas_satisfies(library::GemmDescription const &desc); | |
| /// Returns a status if cuBLAS can satisfy a particular RankK description | |
| Status cublas_satisfies(library::RankKDescription const &desc); | |
| /// Returns a status if cuBLAS can satisfy a particular TRMM description | |
| Status cublas_satisfies(library::TrmmDescription const &desc); | |
| /// Returns a status if cuBLAS can satisfy a particular SYMM/HEMM description | |
| Status cublas_satisfies(library::SymmDescription const &desc); | |
| /// This is a helper class to create cublasHandle_t automatically on CublasCreate object creation and | |
| /// to destroy cublasHandle_t on CublasCreate object destruction. | |
| /// Additionaly, it provides implicit cast from CublasCreate's object to cublasHandle_t's object | |
| class CublasCreate { | |
| private: | |
| cublasHandle_t handle; | |
| cublasStatus_t status; | |
| public: | |
| CublasCreate() { | |
| status = cublasCreate(&handle); | |
| } | |
| ~CublasCreate() { | |
| cublasDestroy(handle); | |
| } | |
| /// Implicit cast CublasCreate object to cublasHandle_t | |
| operator cublasHandle_t() const { return handle; } | |
| /// returns cublasStatus_t for handle creation | |
| cublasStatus_t get_cublas_create_status() { return status; } | |
| }; | |
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |
| namespace detail { | |
| /// Selects one or more cuBLAS algorithms. | |
| static void select_cublas_algorithms( | |
| std::vector<cublasGemmAlgo_t> &algorithms, | |
| Options const &options, | |
| library::GemmDescription const &op_desc) { | |
| library::OpcodeClassID const & opcode_class = | |
| op_desc.tile_description.math_instruction.opcode_class; | |
| switch (options.library.algorithm_mode) { | |
| case AlgorithmMode::kMatching: | |
| { | |
| algorithms.push_back(get_cublas_gemm_algo( | |
| op_desc.tile_description.threadblock_shape.m(), | |
| op_desc.tile_description.threadblock_shape.n(), | |
| op_desc.tile_description.threadblock_shape.k(), | |
| opcode_class)); | |
| break; | |
| } | |
| case AlgorithmMode::kBest: | |
| { | |
| // Choose first enumerated mode. If none are enumerated, choose based on opcode class | |
| // and evaluate all of them. | |
| if (options.library.algorithms.empty()) { | |
| // Enumerate all algorithms | |
| if (opcode_class == library::OpcodeClassID::kSimt) { | |
| for (int algo = CUBLAS_GEMM_DEFAULT; | |
| algo <= CUBLAS_GEMM_ALGO23; | |
| ++algo) { | |
| algorithms.push_back(cublasGemmAlgo_t(algo)); | |
| } | |
| } | |
| else { | |
| for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; | |
| algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; | |
| ++algo) { | |
| algorithms.push_back(cublasGemmAlgo_t(algo)); | |
| } | |
| } | |
| } | |
| else { | |
| // Use the listed algorithms | |
| algorithms.reserve(options.library.algorithms.size()); | |
| for (int algo : options.library.algorithms) { | |
| algorithms.push_back(reinterpret_cast<cublasGemmAlgo_t const &>(algo)); | |
| } | |
| } | |
| break; | |
| } | |
| case AlgorithmMode::kDefault: | |
| { | |
| // Use the library's default algorithm | |
| algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? | |
| CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); | |
| break; | |
| } | |
| default: | |
| { | |
| break; | |
| } | |
| } | |
| } | |
| /// Dispatcher to cublasGemmEx() | |
| struct cublasGemmExDispatcher { | |
| // | |
| // Data members | |
| // | |
| library::GemmUniversalConfiguration configuration; | |
| library::GemmUniversalArguments arguments; | |
| // cublass-specific data structures to fill cublas API call arguments | |
| cublasOperation_t trans_A; | |
| cublasOperation_t trans_B; | |
| cudaDataType_t data_type_A; | |
| cudaDataType_t data_type_B; | |
| cudaDataType_t data_type_C; | |
| cudaDataType_t compute_data_type; | |
| cublasComputeType_t compute_type; | |
| cublasGemmAlgo_t algo; | |
| Status status; | |
| // | |
| // Methods | |
| // | |
| cublasGemmExDispatcher( | |
| library::GemmDescription const &op_desc, | |
| library::GemmUniversalConfiguration configuration_, | |
| library::GemmUniversalArguments arguments_, | |
| cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT | |
| ); | |
| /// Executes GEMM using these arguments | |
| cublasStatus_t operator()(cublasHandle_t handle); | |
| }; | |
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |
| /// Dispatcher to cublas rank k update kernels | |
| struct cublasRankKDispatcher { | |
| // | |
| // Data members | |
| // | |
| library::RankKConfiguration configuration; | |
| library::RankKArguments arguments; | |
| // cublass-specific data structures to fill cublas API call arguments | |
| cublasOperation_t trans_A; | |
| cublasFillMode_t uplo; | |
| cudaDataType_t data_type_A; | |
| cudaDataType_t data_type_C; | |
| cudaDataType_t compute_data_type; | |
| cublasComputeType_t compute_type; | |
| int num_ranks; //(rank-k or rank-2k) | |
| BlasMode blas_mode; //(symmetric or hermitian) | |
| Status status; | |
| // | |
| // Methods | |
| // | |
| cublasRankKDispatcher( | |
| library::RankKDescription const &op_desc, | |
| library::RankKConfiguration configuration_, | |
| library::RankKArguments arguments_ | |
| ); | |
| /// Executes RankK using these arguments | |
| cublasStatus_t operator()(cublasHandle_t handle); | |
| }; | |
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |
| /// Dispatcher to cublasTrmm() | |
| struct cublasTrmmDispatcher { | |
| // | |
| // Data members | |
| // | |
| library::TrmmConfiguration configuration; | |
| library::TrmmArguments arguments; | |
| // cublass-specific data structures to fill cublas API call arguments | |
| cublasOperation_t trans_A; | |
| cublasSideMode_t side; | |
| cublasFillMode_t uplo; | |
| cublasDiagType_t diag; | |
| cudaDataType_t data_type_A; | |
| cudaDataType_t data_type_B; | |
| cudaDataType_t data_type_D; | |
| cudaDataType_t compute_data_type; | |
| cublasComputeType_t compute_type; | |
| Status status; | |
| // | |
| // Methods | |
| // | |
| cublasTrmmDispatcher( | |
| library::TrmmDescription const &op_desc, | |
| library::TrmmConfiguration configuration_, | |
| library::TrmmArguments arguments_ | |
| ); | |
| /// Executes TRMM using these arguments | |
| cublasStatus_t operator()(cublasHandle_t handle); | |
| }; | |
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |
| /// Dispatcher to cublas symm/hemm update kernels | |
| struct cublasSymmDispatcher { | |
| // | |
| // Data members | |
| // | |
| library::SymmConfiguration configuration; | |
| library::SymmArguments arguments; | |
| // cublass-specific data structures to fill cublas API call arguments | |
| cublasSideMode_t side; | |
| cublasFillMode_t uplo; | |
| cudaDataType_t data_type_A; | |
| cudaDataType_t data_type_B; | |
| cudaDataType_t data_type_C; | |
| cudaDataType_t compute_data_type; | |
| cublasComputeType_t compute_type; | |
| BlasMode blas_mode; //(symmetric or hermitian) | |
| Status status; | |
| // | |
| // Methods | |
| // | |
| cublasSymmDispatcher( | |
| library::SymmDescription const &op_desc, | |
| library::SymmConfiguration configuration_, | |
| library::SymmArguments arguments_ | |
| ); | |
| /// Executes Symm using these arguments | |
| cublasStatus_t operator()(cublasHandle_t handle); | |
| }; | |
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |
| } // namespace detail | |
| } // namespace profiler | |
| } // namespace cutlass | |