| #include <cublas_v2.h> |
| #include <cuda.h> |
| #include <cuda_fp16.h> |
| #include <cuda_runtime.h> |
| #include <torch/extension.h> |
| #include <c10/cuda/CUDAGuard.h> |
| #include <ATen/cuda/CUDAContext.h> |
|
|
| #define CUBLAS_CHECK(condition) \ |
| for (cublasStatus_t _cublas_check_status = (condition); \ |
| _cublas_check_status != CUBLAS_STATUS_SUCCESS;) \ |
| throw std::runtime_error("cuBLAS error " + \ |
| std::to_string(_cublas_check_status) + " at " + \ |
| std::to_string(__LINE__)); |
|
|
| #define CUDA_CHECK(condition) \ |
| for (cudaError_t _cuda_check_status = (condition); \ |
| _cuda_check_status != cudaSuccess;) \ |
| throw std::runtime_error( \ |
| "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \ |
| " at " + std::to_string(__LINE__)); |
|
|
| |
| |
| |
| |
| |
| void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) { |
| const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); |
| const auto cuda_data_type = CUDA_R_16F; |
| const auto cuda_c_data_type = |
| c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F; |
| const auto compute_type = CUDA_R_32F; |
| const float sp_alpha = 1.f; |
| |
| std::swap(a, b); |
| const cublasOperation_t cublas_trans_a = CUBLAS_OP_N; |
| const cublasOperation_t cublas_trans_b = CUBLAS_OP_N; |
| |
| |
| const int m = a.size(-1); |
| const int k = a.size(-2); |
| const int n = b.size(-2); |
| const int cublas_lda = m; |
| const int cublas_ldb = k; |
| const int cublas_ldc = m; |
| cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle(); |
|
|
| #if CUDA_VERSION >= 11000 |
| cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; |
| #else |
| cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP; |
| #endif |
| const float sp_beta = 0.f; |
| if (a.sizes().size() == 2 && b.sizes().size() == 2) { |
| CUBLAS_CHECK(cublasGemmEx( |
| cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha, |
| a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type, |
| cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, |
| compute_type, algo)); |
| } else { |
| |
| assert(a.sizes().size() == 3 && b.sizes().size() == 3); |
|
|
| const long long int cublas_stride_a = m * k; |
| const long long int cublas_stride_b = k * n; |
| const long long int cublas_stride_c = m * n; |
| CUBLAS_CHECK(cublasGemmStridedBatchedEx( |
| cublas_handle, cublas_trans_a, cublas_trans_b, m, |
| n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda, |
| cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b, |
| &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c, |
| a.size(0), compute_type, algo)); |
| } |
| } |
|
|