File size: 12,818 Bytes
c1af2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#pragma once

#include <cublas_v2.h>
#include <cusparse.h>
#include <c10/macros/Export.h>

#if !defined(USE_ROCM)
#include <cusolver_common.h>
#else
#include <hipsolver/hipsolver.h>
#endif

#if defined(USE_CUDSS)
#include <cudss.h>
#endif

#include <ATen/Context.h>
#include <c10/util/Exception.h>
#include <c10/cuda/CUDAException.h>


namespace c10 {

// Error type that the cuDNN check macros in this header (AT_CUDNN_CHECK and
// AT_CUDNN_FRONTEND_CHECK) pass to TORCH_CHECK_WITH.
// It adds no members of its own; the using-declaration only inherits the
// constructors of c10::Error.
class CuDNNError : public c10::Error {
  using Error::Error;
};

}  // namespace c10

// Checks the result object returned by a cuDNN frontend call.
//
// EXPR is evaluated exactly once. Its result must offer `is_good()` and
// `get_message()`; when `is_good()` is false the failure is reported through
// TORCH_CHECK_WITH as a CuDNNError whose text is "cuDNN Frontend error: "
// followed by `get_message()`. Extra arguments are accepted but not used.
//
// The last line intentionally carries no line-continuation backslash: the
// definition must end at `while (0)` so it can never absorb the next line.
#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...)                                                      \
  do {                                                                                          \
    auto error_object = EXPR;                                                                   \
    if (!error_object.is_good()) {                                                              \
      TORCH_CHECK_WITH(CuDNNError, false,                                                       \
            "cuDNN Frontend error: ", error_object.get_message());                              \
    }                                                                                           \
  } while (0)

// Variant of AT_CUDNN_CHECK that puts a newline in front of the caller's extra
// message arguments (going by the name, these are meant to be shape
// descriptions). All checking is delegated to AT_CUDNN_CHECK.
#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)

// See Note [CHECK macro]
//
// Evaluates EXPR once as a cudnnStatus_t. Every status other than
// CUDNN_STATUS_SUCCESS is reported through TORCH_CHECK_WITH as a CuDNNError
// whose message starts with "cuDNN error: " and the cuDNN error string, and
// ends with any extra arguments supplied by the caller.
// CUDNN_STATUS_NOT_SUPPORTED additionally gets a hint about non-contiguous
// input, placed before the caller's extra arguments.
#define AT_CUDNN_CHECK(EXPR, ...)                                                               \
  do {                                                                                          \
    cudnnStatus_t status = EXPR;                                                                \
    if (status == CUDNN_STATUS_NOT_SUPPORTED) {                                                 \
      TORCH_CHECK_WITH(CuDNNError, false,                                                       \
          "cuDNN error: ",                                                                      \
          cudnnGetErrorString(status),                                                          \
          ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__);   \
    } else if (status != CUDNN_STATUS_SUCCESS) {                                                \
      TORCH_CHECK_WITH(CuDNNError, false,                                                       \
          "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__);                         \
    }                                                                                           \
  } while (0)

namespace at::cuda::blas {
// Maps a cuBLAS status code to a C string for error messages; consumed by
// TORCH_CUDABLAS_CHECK. Only declared here -- the definition (and therefore
// the exact strings it returns) lives elsewhere.
C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
} // namespace at::cuda::blas

// Runs a cuBLAS call and verifies its status.
//
// EXPR is evaluated once into a cublasStatus_t. Any value other than
// CUBLAS_STATUS_SUCCESS fails the TORCH_CHECK, whose message names the status
// (via at::cuda::blas::_cublasGetErrorEnum) and quotes the text of EXPR.
#define TORCH_CUDABLAS_CHECK(EXPR)                                      \
  do {                                                                  \
    const cublasStatus_t __err = EXPR;                                  \
    TORCH_CHECK(                                                        \
        __err == CUBLAS_STATUS_SUCCESS,                                 \
        "CUDA error: ",                                                 \
        at::cuda::blas::_cublasGetErrorEnum(__err),                     \
        " when calling `" #EXPR "`");                                   \
  } while (0)

// String lookup that TORCH_CUDASPARSE_CHECK uses to describe a
// cusparseStatus_t in its error message.
// NOTE(review): presumably this re-declares the function of the same name
// from <cusparse.h> so it is visible regardless of library version -- confirm.
const char *cusparseGetErrorString(cusparseStatus_t status);

// Runs a cuSPARSE call and verifies its status.
//
// EXPR is evaluated once into a cusparseStatus_t. Any value other than
// CUSPARSE_STATUS_SUCCESS fails the TORCH_CHECK, whose message contains the
// string from cusparseGetErrorString and quotes the text of EXPR.
#define TORCH_CUDASPARSE_CHECK(EXPR)                                    \
  do {                                                                  \
    const cusparseStatus_t __err = EXPR;                                \
    TORCH_CHECK(                                                        \
        __err == CUSPARSE_STATUS_SUCCESS,                               \
        "CUDA error: ",                                                 \
        cusparseGetErrorString(__err),                                  \
        " when calling `" #EXPR "`");                                   \
  } while (0)

#if defined(USE_CUDSS)
namespace at::cuda::cudss {
// Returns a C string describing a cuDSS status; TORCH_CUDSS_CHECK embeds it in
// its error messages. Only declared here; the definition lives elsewhere.
C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t error);
} // namespace at::cuda::cudss

// Verifies the status returned by a cuDSS call (EXPR is evaluated once).
//
// CUDSS_STATUS_EXECUTION_FAILED is reported through TORCH_CHECK_LINALG with a
// hint that the input matrix may contain NaN. Every other status goes through
// TORCH_CHECK and must equal CUDSS_STATUS_SUCCESS. Both messages contain the
// string from cudssGetErrorMessage and quote the text of EXPR.
#define TORCH_CUDSS_CHECK(EXPR)                                         \
  do {                                                                  \
    const cudssStatus_t __err = EXPR;                                   \
    if (__err != CUDSS_STATUS_EXECUTION_FAILED) {                       \
      TORCH_CHECK(                                                      \
          __err == CUDSS_STATUS_SUCCESS,                                \
          "cudss error: ",                                              \
          at::cuda::cudss::cudssGetErrorMessage(__err),                 \
          ", when calling `" #EXPR "`. ");                              \
    } else {                                                            \
      TORCH_CHECK_LINALG(                                               \
          false,                                                        \
          "cudss error: ",                                              \
          at::cuda::cudss::cudssGetErrorMessage(__err),                 \
          ", when calling `" #EXPR "`",                                 \
          ". This error may appear if the input matrix contains NaN. ");\
    }                                                                   \
  } while (0)
#else
// Fallback when USE_CUDSS is not defined: no status checking is performed and
// the macro expands to EXPR itself.
#define TORCH_CUDSS_CHECK(EXPR) EXPR
#endif

namespace at::cuda::solver {
#if !defined(USE_ROCM)

// Returns a C string describing a cuSOLVER status; TORCH_CUSOLVER_CHECK embeds
// it in its error messages. Only declared here; the definition lives elsewhere.
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);

// Advice appended to every TORCH_CUSOLVER_CHECK failure message: points users
// at torch.backends.cuda.preferred_linalg_library() as a way to try another
// linear algebra backend.
constexpr const char* _cusolver_backend_suggestion =
    "If you keep seeing this error, you may use "
    "`torch.backends.cuda.preferred_linalg_library()` "
    "to try linear algebra operators with other supported backends. "
    "See https://pytorch.org/docs/stable/backends.html"
    "#torch.backends.cuda.preferred_linalg_library";

// Verifies the status returned by a cuSOLVER call (EXPR is evaluated once).
//
// CUSOLVER_STATUS_INVALID_VALUE is reported through TORCH_CHECK_LINALG with a
// hint that the input matrix may contain NaN. Every other status goes through
// TORCH_CHECK and must equal CUSOLVER_STATUS_SUCCESS. Both messages quote the
// text of EXPR and end with _cusolver_backend_suggestion.
//
// When cuda >= 11.5, cusolver normally finishes execution and sets info array
// indicating convergence issue.
#define TORCH_CUSOLVER_CHECK(EXPR)                                      \
  do {                                                                  \
    const cusolverStatus_t __err = EXPR;                                \
    if (__err != CUSOLVER_STATUS_INVALID_VALUE) {                       \
      TORCH_CHECK(                                                      \
          __err == CUSOLVER_STATUS_SUCCESS,                             \
          "cusolver error: ",                                           \
          at::cuda::solver::cusolverGetErrorMessage(__err),             \
          ", when calling `" #EXPR "`. ",                               \
          at::cuda::solver::_cusolver_backend_suggestion);              \
    } else {                                                            \
      TORCH_CHECK_LINALG(                                               \
          false,                                                        \
          "cusolver error: ",                                           \
          at::cuda::solver::cusolverGetErrorMessage(__err),             \
          ", when calling `" #EXPR "`",                                 \
          ". This error may appear if the input matrix contains NaN. ", \
          at::cuda::solver::_cusolver_backend_suggestion);              \
    }                                                                   \
  } while (0)

#else // defined(USE_ROCM)

// Returns a C string describing a hipSOLVER status; the ROCm flavour of
// TORCH_CUSOLVER_CHECK embeds it in its error messages. Only declared here;
// the definition lives elsewhere.
C10_EXPORT const char* hipsolverGetErrorMessage(hipsolverStatus_t status);

// Advice appended to every failure message of the ROCm TORCH_CUSOLVER_CHECK:
// points users at torch.backends.cuda.preferred_linalg_library() as a way to
// try another linear algebra backend.
constexpr const char* _hipsolver_backend_suggestion =
    "If you keep seeing this error, you may use "
    "`torch.backends.cuda.preferred_linalg_library()` "
    "to try linear algebra operators with other supported backends. "
    "See https://pytorch.org/docs/stable/backends.html"
    "#torch.backends.cuda.preferred_linalg_library";

// ROCm flavour of TORCH_CUSOLVER_CHECK: verifies the status returned by a
// hipSOLVER call (EXPR is evaluated once).
//
// HIPSOLVER_STATUS_INVALID_VALUE is reported through TORCH_CHECK_LINALG with a
// hint that the input matrix may contain NaN. Every other status goes through
// TORCH_CHECK and must equal HIPSOLVER_STATUS_SUCCESS. Both messages quote the
// text of EXPR and end with _hipsolver_backend_suggestion.
#define TORCH_CUSOLVER_CHECK(EXPR)                                      \
  do {                                                                  \
    const hipsolverStatus_t __err = EXPR;                               \
    if (__err != HIPSOLVER_STATUS_INVALID_VALUE) {                      \
      TORCH_CHECK(                                                      \
          __err == HIPSOLVER_STATUS_SUCCESS,                            \
          "hipsolver error: ",                                          \
          at::cuda::solver::hipsolverGetErrorMessage(__err),            \
          ", when calling `" #EXPR "`. ",                               \
          at::cuda::solver::_hipsolver_backend_suggestion);             \
    } else {                                                            \
      TORCH_CHECK_LINALG(                                               \
          false,                                                        \
          "hipsolver error: ",                                          \
          at::cuda::solver::hipsolverGetErrorMessage(__err),            \
          ", when calling `" #EXPR "`",                                 \
          ". This error may appear if the input matrix contains NaN. ", \
          at::cuda::solver::_hipsolver_backend_suggestion);             \
    }                                                                   \
  } while (0)
#endif
} // namespace at::cuda::solver

// ATen-side spelling of the CUDA runtime check: forwards EXPR unchanged to
// C10_CUDA_CHECK.
#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR)

// For CUDA Driver API
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its cuGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#if !defined(USE_ROCM)

// Verifies the CUresult of a CUDA driver API call (EXPR is evaluated once).
//
// On failure the error text is looked up with cuGetErrorString, reached
// through at::globalContext().getNVRTC(). If that lookup itself fails, the
// message falls back to "CUDA driver error: unknown error"; otherwise the
// looked-up string follows "CUDA driver error: ".
#define AT_CUDA_DRIVER_CHECK(EXPR)                                            \
  do {                                                                        \
    const CUresult __err = EXPR;                                              \
    if (__err != CUDA_SUCCESS) {                                              \
      const char* description;                                                \
      [[maybe_unused]] const CUresult lookup_status =                         \
          at::globalContext().getNVRTC().cuGetErrorString(__err, &description); \
      if (lookup_status == CUDA_SUCCESS) {                                    \
        TORCH_CHECK(false, "CUDA driver error: ", description);               \
      } else {                                                                \
        TORCH_CHECK(false, "CUDA driver error: unknown error");               \
      }                                                                       \
    }                                                                         \
  } while (0)

#else

// ROCm flavour of AT_CUDA_DRIVER_CHECK (EXPR is evaluated once).
//
// No error-string lookup is done here: a failing status is reported as
// "CUDA driver error: " followed by the numeric value of the status code.
#define AT_CUDA_DRIVER_CHECK(EXPR)                                        \
  do {                                                                    \
    const CUresult __err = EXPR;                                          \
    if (__err != CUDA_SUCCESS) {                                          \
      TORCH_CHECK(false, "CUDA driver error: ", static_cast<int>(__err)); \
    }                                                                     \
  } while (0)

#endif

// For CUDA NVRTC
//
// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
// incorrectly produces the error string "NVRTC unknown error."
// The following maps it correctly.
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
// Verifies the nvrtcResult of an NVRTC call (EXPR is evaluated once).
//
// Error code 7 is special-cased and reported with the literal name
// NVRTC_ERROR_BUILTIN_OPERATION_FAILURE; every other failure uses the string
// from nvrtcGetErrorString, reached through at::globalContext().getNVRTC().
#define AT_CUDA_NVRTC_CHECK(EXPR)                                                                   \
  do {                                                                                              \
    const nvrtcResult __err = EXPR;                                                                 \
    if (__err != NVRTC_SUCCESS) {                                                                   \
      if (static_cast<int>(__err) == 7) {                                                           \
        TORCH_CHECK(false, "CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE");              \
      } else {                                                                                      \
        TORCH_CHECK(false, "CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \
      }                                                                                             \
    }                                                                                               \
  } while (0)