| | #pragma once |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include <ATen/cuda/CUDAContext.h> |
| | #include <ATen/OpMathType.h> |
| |
|
| | namespace at { |
| | namespace cuda { |
| | namespace blas { |
| |
|
| | |
| | |
| | class PointerModeGuard { |
| | public: |
| | PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) : |
| | handle(handle) { |
| | TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode)); |
| | TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode)); |
| | } |
| |
|
| | ~PointerModeGuard() { |
| | cublasSetPointerMode(handle, previous_mode); |
| | } |
| |
|
| | private: |
| | cublasHandle_t handle; |
| | cublasPointerMode_t previous_mode; |
| | }; |
| |
|
| | |
| |
|
| | #define CUDABLAS_GEMM_ARGTYPES(Dtype) \ |
| | char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \ |
| | const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,\ |
| | Dtype *c, int64_t ldc |
| |
|
| | template <typename Dtype> |
| | inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) { |
| | AT_ERROR("at::cuda::blas::gemm: not implemented for ", typeid(Dtype).name()); |
| | } |
| |
|
| | template <> |
| | void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)); |
| | template <> |
| | void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)); |
| | #if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000) |
| | template <> |
| | void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)); |
| | #endif |
| | #if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000) |
| | template <> |
| | void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)); |
| | #endif |
| | template <> |
| | void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)); |
| | #if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
| | template <> |
| | void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); |
| | #endif |
| |
|
// Fused gemm + bias (+ optional activation) is only available with
// CUDA >= 11.0 on non-MSVC toolchains, per this guard.
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(_MSC_VER)
// Activation fused into the epilogue of gemm_and_bias().
enum GEMMAndBiasActivationEpilogue {
  None,
  RELU,
  GELU,
};

// NOTE(review): judging by the parameter names, this computes something like
//   result = activation(alpha * op(mat1) @ op(mat2) + bias)
// with op() controlled by the transpose flags — confirm against the
// definition in the .cpp file. Implemented out of line; 'activation'
// defaults to no epilogue.
template <typename Dtype>
void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<Dtype> alpha_val,
    const Dtype* mat1_ptr,
    int64_t mat1_ld,
    const Dtype* mat2_ptr,
    int64_t mat2_ld,
    const Dtype* bias,
    Dtype* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
#endif
| |
|
| | #define CUDABLAS_BGEMM_ARGTYPES(Dtype) \ |
| | char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \ |
| | const Dtype *a, int64_t lda, int64_t stridea, \ |
| | const Dtype *b, int64_t ldb, int64_t strideb, \ |
| | at::opmath_type<Dtype> beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches |
| |
|
| | template <typename Dtype> |
| | inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { |
| | AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name()); |
| | } |
| |
|
// Explicit bgemm specializations, implemented out of line in the .cpp.
template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
// BFloat16 requires either a ROCm build or CUDA >= 11.0 (&& binds tighter
// than ||, so the condition parses as intended).
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
#endif
| |
|
// Argument list for trsm(): unlike the gemm/gemv wrappers above, this takes
// the cuBLAS handle and raw cuBLAS enums (side/uplo/trans/diag) directly,
// mirroring the cublas<t>trsm interface.
#define CUDABLAS_TRSM_ARGTYPES(Dtype)                                  \
  cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
  cublasOperation_t trans, cublasDiagType_t diag, int m, int n,        \
  const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb

// Primary template: only the explicit specializations declared below are
// supported; any other Dtype fails loudly at runtime.
template <typename Dtype>
inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) {
  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::trsm: not implemented for ", typeid(Dtype).name());
}

// Specializations implemented out of line; declared with TORCH_CUDA_CU_API
// like the other solver-style helpers in this header.
template <>
TORCH_CUDA_CU_API void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>));
template <>
TORCH_CUDA_CU_API void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>));
| |
|
// Argument list for trsmBatched(): the trsm arguments with A/B replaced by
// arrays of matrix pointers plus a batch count (cublas<t>trsmBatched-style).
#define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)                          \
  cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
  cublasOperation_t trans, cublasDiagType_t diag, int m, int n,        \
  const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb,        \
  int batchCount

// Primary template: only the explicit specializations declared below are
// supported; any other Dtype fails loudly at runtime.
template <typename Dtype>
inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) {
  TORCH_INTERNAL_ASSERT(
      false,
      "at::cuda::blas::trsmBatched: not implemented for ",
      typeid(Dtype).name());
}

// Specializations implemented out of line in the .cpp.
template <>
TORCH_CUDA_CU_API void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void trsmBatched<c10::complex<float>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>));
template <>
TORCH_CUDA_CU_API void trsmBatched<c10::complex<double>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>));
| |
|
| | |
| |
|
| | #define CUDABLAS_GEMV_ARGTYPES(Dtype) \ |
| | char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \ |
| | const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy |
| |
|
| | template <typename Dtype> |
| | inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) { |
| | AT_ERROR("at::cuda::blas::gemv: not implemented for ", typeid(Dtype).name()); |
| | } |
| |
|
// Explicit gemv specializations, implemented out of line in the .cpp.
template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
// Complex support is gated on ROCM_VERSION >= 21000 for ROCm builds.
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
#endif
template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
// BFloat16 requires either a ROCm build or CUDA >= 11.0.
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
#endif
| |
|
| | |
| |
|
| | #define CUDABLAS_DOT_ARGTYPES(Dtype) \ |
| | cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \ |
| | int incy, Dtype *result |
| |
|
| | template <typename Dtype> |
| | inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) { |
| | AT_ERROR("at::cuda::blas::dot: not implemented for ", typeid(Dtype).name()); |
| | } |
| |
|
// Explicit dot specializations, implemented out of line in the .cpp.
template <>
void dot<double>(CUDABLAS_DOT_ARGTYPES(double));
template <>
void dot<float>(CUDABLAS_DOT_ARGTYPES(float));
template <>
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half));
template <>
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16));
template <>
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
template <>
void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
| |
|
| | template <typename Dtype> |
| | inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) { |
| | AT_ERROR("at::cuda::blas::vdot: not implemented for ", typeid(Dtype).name()); |
| | } |
| |
|
| | template <> |
| | void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)); |
| | template <> |
| | void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)); |
| |
|
| | |
| | #ifdef CUDART_VERSION |
| |
|
// Argument list for getrsBatched(): dA_array/dB_array are arrays of matrix
// pointers, ipiv_array holds pivot indices, and per-matrix status codes are
// written to info_array (cublas<t>getrsBatched-style LU solve).
// NOTE(review): semantics inferred from the cuBLAS naming convention —
// confirm against the definitions in the .cpp file.
#define CUDABLAS_GETRS_ARGTYPES(Dtype)  \
  cublasHandle_t handle, cublasOperation_t trans, \
  int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
  Dtype** dB_array, int ldb, int* info_array, int batchsize

// Primary template: only the explicit specializations declared below are
// supported; any other Dtype fails loudly at runtime.
template<class Dtype>
void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) {
  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::getrsBatched: not implemented for ",
    typeid(Dtype).name());
}
// Specializations implemented out of line in the .cpp.
template<>
TORCH_CUDA_CU_API void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float));
template<>
TORCH_CUDA_CU_API void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double));
template<>
TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>));
template<>
TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>));
| |
|
// Argument list for geqrfBatched() (cublas<t>geqrfBatched-style batched QR
// factorization): A_array/tau_array are arrays of matrix/vector pointers; a
// single status code is written through 'info'.
// NOTE(review): semantics inferred from the cuBLAS naming convention —
// confirm against the definitions in the .cpp file.
#define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)                   \
  cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
      Dtype **tau_array, int *info, int batchsize

// Primary template: only the explicit specializations declared below are
// supported; any other Dtype fails loudly at runtime.
template <class Dtype>
void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
  TORCH_INTERNAL_ASSERT(
      false,
      "at::cuda::blas::geqrfBatched: not implemented for ",
      typeid(Dtype).name());
}
// Specializations implemented out of line in the .cpp.
template <>
TORCH_CUDA_CU_API void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>));
template <>
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
| |
|
// Argument list for getrfBatched() (cublas<t>getrfBatched-style batched LU
// factorization). Note this one takes no cublasHandle_t — presumably the
// implementation obtains the current handle internally; confirm in the .cpp.
#define CUDABLAS_GETRF_ARGTYPES(Dtype)  \
  int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize

// Primary template: only the explicit specializations declared below are
// supported; any other Dtype fails loudly at runtime.
template<class Dtype>
void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
  TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name());
}
// Specializations implemented out of line in the .cpp.
template<>
TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
template<>
TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
template<>
TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
template<>
TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
| |
|
// Argument list for gelsBatched() (cublas<t>gelsBatched-style batched
// least-squares solve): dA_array/dC_array are arrays of matrix pointers;
// 'info' receives an overall status and devInfoArray per-matrix statuses.
// NOTE(review): semantics inferred from the cuBLAS naming convention —
// confirm against the definitions in the .cpp file.
#define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)  \
  cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize

// Primary template: only the explicit specializations declared below are
// supported; any other Dtype fails loudly at runtime.
template <class Dtype>
void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::gelsBatched: not implemented for ", typeid(Dtype).name());
}

// Specializations implemented out of line in the .cpp.
template<>
TORCH_CUDA_CU_API void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double));
template<>
TORCH_CUDA_CU_API void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float));
template<>
TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>));
template<>
TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
| |
|
| | #endif |
| |
|
| | } |
| | } |
| | } |
| |
|