|
|
#include <cublas_v2.h> |
|
|
#include <cuda.h> |
|
|
#include <cuda_fp16.h> |
|
|
#include <cuda_runtime.h> |
|
|
#include <torch/extension.h> |
|
|
#include <c10/cuda/CUDAGuard.h> |
|
|
#include <ATen/cuda/CUDAContext.h> |
|
|
|
|
|
#define CUBLAS_CHECK(condition) \ |
|
|
for (cublasStatus_t _cublas_check_status = (condition); \ |
|
|
_cublas_check_status != CUBLAS_STATUS_SUCCESS;) \ |
|
|
throw std::runtime_error("cuBLAS error " + \ |
|
|
std::to_string(_cublas_check_status) + " at " + \ |
|
|
std::to_string(__LINE__)); |
|
|
|
|
|
#define CUDA_CHECK(condition) \ |
|
|
for (cudaError_t _cuda_check_status = (condition); \ |
|
|
_cuda_check_status != cudaSuccess;) \ |
|
|
throw std::runtime_error( \ |
|
|
"CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \ |
|
|
" at " + std::to_string(__LINE__)); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) { |
|
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); |
|
|
const auto cuda_data_type = CUDA_R_16F; |
|
|
const auto cuda_c_data_type = |
|
|
c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F; |
|
|
const auto compute_type = CUDA_R_32F; |
|
|
const float sp_alpha = 1.f; |
|
|
|
|
|
std::swap(a, b); |
|
|
const cublasOperation_t cublas_trans_a = CUBLAS_OP_N; |
|
|
const cublasOperation_t cublas_trans_b = CUBLAS_OP_N; |
|
|
|
|
|
|
|
|
const int m = a.size(-1); |
|
|
const int k = a.size(-2); |
|
|
const int n = b.size(-2); |
|
|
const int cublas_lda = m; |
|
|
const int cublas_ldb = k; |
|
|
const int cublas_ldc = m; |
|
|
cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle(); |
|
|
|
|
|
#if CUDA_VERSION >= 11000 |
|
|
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; |
|
|
#else |
|
|
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP; |
|
|
#endif |
|
|
const float sp_beta = 0.f; |
|
|
if (a.sizes().size() == 2 && b.sizes().size() == 2) { |
|
|
CUBLAS_CHECK(cublasGemmEx( |
|
|
cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha, |
|
|
a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type, |
|
|
cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, |
|
|
compute_type, algo)); |
|
|
} else { |
|
|
|
|
|
assert(a.sizes().size() == 3 && b.sizes().size() == 3); |
|
|
|
|
|
const long long int cublas_stride_a = m * k; |
|
|
const long long int cublas_stride_b = k * n; |
|
|
const long long int cublas_stride_c = m * n; |
|
|
CUBLAS_CHECK(cublasGemmStridedBatchedEx( |
|
|
cublas_handle, cublas_trans_a, cublas_trans_b, m, |
|
|
n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda, |
|
|
cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b, |
|
|
&sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c, |
|
|
a.size(0), compute_type, algo)); |
|
|
} |
|
|
} |
|
|
|