File size: 12,818 Bytes
c1af2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
#pragma once
#include <cublas_v2.h>
#include <cusparse.h>
#include <c10/macros/Export.h>
#if !defined(USE_ROCM)
#include <cusolver_common.h>
#else
#include <hipsolver/hipsolver.h>
#endif
#if defined(USE_CUDSS)
#include <cudss.h>
#endif
#include <ATen/Context.h>
#include <c10/util/Exception.h>
#include <c10/cuda/CUDAException.h>
namespace c10 {
class CuDNNError : public c10::Error {
using Error::Error;
};
} // namespace c10
// Evaluates EXPR exactly once. The result is assumed to expose is_good() and
// get_message() (as the cuDNN frontend error object does -- inferred from the
// calls below). When is_good() is false, throws c10::CuDNNError with the
// message "cuDNN Frontend error: " + get_message(), followed by any extra
// arguments the caller passed (mirroring AT_CUDNN_CHECK).
// NOTE: the final line deliberately has no trailing backslash; a continuation
// there would splice the next #define into this macro's body.
#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...)                          \
  do {                                                              \
    auto error_object = (EXPR);                                     \
    if (!error_object.is_good()) {                                  \
      TORCH_CHECK_WITH(CuDNNError, false,                           \
          "cuDNN Frontend error: ", error_object.get_message(),     \
          ##__VA_ARGS__);                                           \
    }                                                               \
  } while (0)
// Variant of AT_CUDNN_CHECK that inserts "\n" ahead of the caller-supplied
// extra message arguments (the name suggests these carry shape information).
#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) \
  AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)
// See Note [CHECK macro]
// Evaluates EXPR (expected to yield a cudnnStatus_t) exactly once and throws
// c10::CuDNNError for any status other than CUDNN_STATUS_SUCCESS. The message
// is "cuDNN error: " + cudnnGetErrorString(status), plus any extra arguments.
// CUDNN_STATUS_NOT_SUPPORTED additionally gets a hint about non-contiguous input.
#define AT_CUDNN_CHECK(EXPR, ...)                                              \
  do {                                                                         \
    cudnnStatus_t status = EXPR;                                               \
    if (status == CUDNN_STATUS_NOT_SUPPORTED) {                                \
      /* NOT_SUPPORTED gets an extra hint before the caller's arguments. */    \
      TORCH_CHECK_WITH(CuDNNError, false,                                      \
          "cuDNN error: ",                                                     \
          cudnnGetErrorString(status),                                         \
          ". This error may appear if you passed in a non-contiguous input.",  \
          ##__VA_ARGS__);                                                      \
    } else if (status != CUDNN_STATUS_SUCCESS) {                               \
      TORCH_CHECK_WITH(CuDNNError, false,                                      \
          "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__);        \
    }                                                                          \
  } while (0)
namespace at::cuda::blas {
// Returns the text used for a cuBLAS status in TORCH_CUDABLAS_CHECK failure
// messages (presumably the enumerator's name). Declared here; defined
// outside this header.
C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t status);
} // namespace at::cuda::blas
// Evaluates EXPR (a call returning cublasStatus_t) exactly once and fails a
// TORCH_CHECK unless the result is CUBLAS_STATUS_SUCCESS. The message names
// the status via _cublasGetErrorEnum and quotes the source text of EXPR.
#define TORCH_CUDABLAS_CHECK(EXPR)                                 \
  do {                                                             \
    const cublasStatus_t __err = (EXPR);                           \
    TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS,                    \
                "CUDA error: ",                                    \
                at::cuda::blas::_cublasGetErrorEnum(__err),        \
                " when calling `" #EXPR "`");                      \
  } while (0)
// Text for a cuSPARSE status code, used by TORCH_CUDASPARSE_CHECK messages.
// Declared at global scope, not inside an ATen namespace.
// NOTE(review): cuSPARSE headers may declare a function with this name
// themselves -- confirm the signatures stay compatible.
const char* cusparseGetErrorString(cusparseStatus_t code);
// Evaluates EXPR (a call returning cusparseStatus_t) exactly once and fails a
// TORCH_CHECK unless the result is CUSPARSE_STATUS_SUCCESS. The message uses
// cusparseGetErrorString for the status and quotes the source text of EXPR.
#define TORCH_CUDASPARSE_CHECK(EXPR)                               \
  do {                                                             \
    const cusparseStatus_t __err = (EXPR);                         \
    TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS,                  \
                "CUDA error: ",                                    \
                cusparseGetErrorString(__err),                     \
                " when calling `" #EXPR "`");                      \
  } while (0)
#if defined(USE_CUDSS)
namespace at::cuda::cudss {
// Returns the text for a cuDSS status code used in TORCH_CUDSS_CHECK failure
// messages. Declared here; defined outside this header.
C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t error);
} // namespace at::cuda::cudss
// Evaluates EXPR (expected to yield a cudssStatus_t) exactly once.
// CUDSS_STATUS_EXECUTION_FAILED is raised through TORCH_CHECK_LINALG with a
// hint that NaNs in the input matrix may be the cause. Any other status that is
// not CUDSS_STATUS_SUCCESS fails a plain TORCH_CHECK.
#define TORCH_CUDSS_CHECK(EXPR)                                            \
  do {                                                                     \
    const cudssStatus_t __err = (EXPR);                                    \
    if (__err != CUDSS_STATUS_EXECUTION_FAILED) {                          \
      TORCH_CHECK(                                                         \
          __err == CUDSS_STATUS_SUCCESS,                                   \
          "cudss error: ",                                                 \
          at::cuda::cudss::cudssGetErrorMessage(__err),                    \
          ", when calling `" #EXPR "`. ");                                 \
    } else {                                                               \
      TORCH_CHECK_LINALG(                                                  \
          false,                                                           \
          "cudss error: ",                                                 \
          at::cuda::cudss::cudssGetErrorMessage(__err),                    \
          ", when calling `" #EXPR "`",                                    \
          ". This error may appear if the input matrix contains NaN. ");   \
    }                                                                      \
  } while (0)
#else
// Without USE_CUDSS the check is a pass-through: EXPR is expanded unchecked.
#define TORCH_CUDSS_CHECK(EXPR) \
  EXPR
#endif
namespace at::cuda::solver {
#if !defined(USE_ROCM)

// Text for a cuSOLVER status code, used in TORCH_CUSOLVER_CHECK messages.
// Declared here; defined outside this header.
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);

// Suffix appended to every TORCH_CUSOLVER_CHECK failure message.
constexpr const char* _cusolver_backend_suggestion =
    "If you keep seeing this error, you may use "
    "`torch.backends.cuda.preferred_linalg_library()` to try "
    "linear algebra operators with other supported backends. "
    "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";

// When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
//
// Evaluates EXPR (expected to yield a cusolverStatus_t) exactly once.
// CUSOLVER_STATUS_INVALID_VALUE is raised through TORCH_CHECK_LINALG with a
// hint that NaNs in the input matrix may be the cause. Any other status that is
// not CUSOLVER_STATUS_SUCCESS fails a plain TORCH_CHECK. Both messages end with
// _cusolver_backend_suggestion.
#define TORCH_CUSOLVER_CHECK(EXPR)                                         \
  do {                                                                     \
    const cusolverStatus_t __err = (EXPR);                                 \
    if (__err != CUSOLVER_STATUS_INVALID_VALUE) {                          \
      TORCH_CHECK(                                                         \
          __err == CUSOLVER_STATUS_SUCCESS,                                \
          "cusolver error: ",                                              \
          at::cuda::solver::cusolverGetErrorMessage(__err),                \
          ", when calling `" #EXPR "`. ",                                  \
          at::cuda::solver::_cusolver_backend_suggestion);                 \
    } else {                                                               \
      TORCH_CHECK_LINALG(                                                  \
          false,                                                           \
          "cusolver error: ",                                              \
          at::cuda::solver::cusolverGetErrorMessage(__err),                \
          ", when calling `" #EXPR "`",                                    \
          ". This error may appear if the input matrix contains NaN. ",    \
          at::cuda::solver::_cusolver_backend_suggestion);                 \
    }                                                                      \
  } while (0)

#else // defined(USE_ROCM)

// Text for a hipSOLVER status code, used in TORCH_CUSOLVER_CHECK messages.
// Declared here; defined outside this header.
C10_EXPORT const char* hipsolverGetErrorMessage(hipsolverStatus_t status);

// Suffix appended to every TORCH_CUSOLVER_CHECK failure message on ROCm.
constexpr const char* _hipsolver_backend_suggestion =
    "If you keep seeing this error, you may use "
    "`torch.backends.cuda.preferred_linalg_library()` to try "
    "linear algebra operators with other supported backends. "
    "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";

// ROCm counterpart of the cuSOLVER macro above: same structure, but it checks
// hipsolverStatus_t values and uses the hipSOLVER message helper and suffix.
#define TORCH_CUSOLVER_CHECK(EXPR)                                         \
  do {                                                                     \
    const hipsolverStatus_t __err = (EXPR);                                \
    if (__err != HIPSOLVER_STATUS_INVALID_VALUE) {                         \
      TORCH_CHECK(                                                         \
          __err == HIPSOLVER_STATUS_SUCCESS,                               \
          "hipsolver error: ",                                             \
          at::cuda::solver::hipsolverGetErrorMessage(__err),               \
          ", when calling `" #EXPR "`. ",                                  \
          at::cuda::solver::_hipsolver_backend_suggestion);                \
    } else {                                                               \
      TORCH_CHECK_LINALG(                                                  \
          false,                                                           \
          "hipsolver error: ",                                             \
          at::cuda::solver::hipsolverGetErrorMessage(__err),               \
          ", when calling `" #EXPR "`",                                    \
          ". This error may appear if the input matrix contains NaN. ",    \
          at::cuda::solver::_hipsolver_backend_suggestion);                \
    }                                                                      \
  } while (0)

#endif
} // namespace at::cuda::solver
// ATen spelling of the c10 CUDA check: forwards EXPR unchanged to
// C10_CUDA_CHECK (from the c10/cuda/CUDAException.h include above).
#define AT_CUDA_CHECK(EXPR) \
  C10_CUDA_CHECK(EXPR)
// For CUDA Driver API
//
// This is here instead of in c10 because the driver API is reached through the
// NVRTC stub that ATen loads dynamically (at::globalContext().getNVRTC()), and
// we need to use that stub's cuGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#if !defined(USE_ROCM)
// Evaluates EXPR (expected to yield a CUresult) exactly once. On failure, it
// throws via TORCH_CHECK with the driver's description, which is obtained from
// the NVRTC stub's cuGetErrorString. If no description is available, the numeric
// CUresult is reported instead, so the failure is still identifiable.
#define AT_CUDA_DRIVER_CHECK(EXPR)                                          \
  do {                                                                      \
    CUresult __err = EXPR;                                                  \
    if (__err != CUDA_SUCCESS) {                                            \
      const char* err_str = nullptr;                                        \
      CUresult get_error_str_err =                                          \
          at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \
      if (get_error_str_err != CUDA_SUCCESS || err_str == nullptr) {        \
        /* Keep the raw code: "unknown error" alone is undiagnosable. */    \
        TORCH_CHECK(false, "CUDA driver error: unknown error (CUresult ",   \
                    static_cast<int>(__err), ")");                          \
      } else {                                                              \
        TORCH_CHECK(false, "CUDA driver error: ", err_str);                 \
      }                                                                     \
    }                                                                       \
  } while (0)
#else
// ROCm build: evaluates EXPR exactly once and, on any result other than
// CUDA_SUCCESS, fails a TORCH_CHECK that reports the raw numeric status.
#define AT_CUDA_DRIVER_CHECK(EXPR)                                          \
  do {                                                                      \
    CUresult __err = EXPR;                                                  \
    if (__err != CUDA_SUCCESS)                                              \
      TORCH_CHECK(false, "CUDA driver error: ", static_cast<int>(__err));   \
  } while (0)
#endif
// For CUDA NVRTC
//
// Evaluates EXPR (expected to yield an nvrtcResult) exactly once and fails a
// TORCH_CHECK on anything other than NVRTC_SUCCESS.
//
// Error code 7 (NVRTC_ERROR_BUILTIN_OPERATION_FAILURE) is special-cased. As of
// CUDA 10, nvrtcGetErrorString reports it as "NVRTC unknown error.", so the
// macro spells out the correct name itself.
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#define AT_CUDA_NVRTC_CHECK(EXPR)                                              \
  do {                                                                         \
    nvrtcResult __err = EXPR;                                                  \
    if (__err != NVRTC_SUCCESS) {                                              \
      if (static_cast<int>(__err) == 7) {                                      \
        /* Code 7 gets a hard-coded name; see the note above. */               \
        TORCH_CHECK(false, "CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \
      } else {                                                                 \
        TORCH_CHECK(false, "CUDA NVRTC error: ",                               \
                    at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \
      }                                                                        \
    }                                                                          \
  } while (0)
|