#pragma once #include #include #include #include #include "compatibility.hpp" namespace deep_gemm { class DGException final : public std::exception { std::string message = {}; public: explicit DGException(const char *name, const char* file, const int line, const std::string& error) { message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error; } const char *what() const noexcept override { return message.c_str(); } }; #ifndef DG_STATIC_ASSERT #define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) #endif #ifndef DG_HOST_ASSERT #define DG_HOST_ASSERT(cond) \ do { \ if (not (cond)) { \ throw DGException("Assertion", __FILE__, __LINE__, #cond); \ } \ } while (0) #endif #ifndef DG_HOST_UNREACHABLE #define DG_HOST_UNREACHABLE(reason) (throw DGException("Assertion", __FILE__, __LINE__, reason)) #endif #ifndef DG_CUDA_DRIVER_CHECK #define DG_CUDA_DRIVER_CHECK(cmd) \ do { \ const auto& e = (cmd); \ if (e != CUDA_SUCCESS) { \ std::stringstream ss; \ const char *name, *info; \ lazy_cuGetErrorName(e, &name), lazy_cuGetErrorString(e, &info); \ ss << static_cast(e) << " (" << name << ", " << info << ")"; \ throw DGException("CUDA driver", __FILE__, __LINE__, ss.str()); \ } \ } while (0) #endif #ifndef DG_CUDA_RUNTIME_CHECK #define DG_CUDA_RUNTIME_CHECK(cmd) \ do { \ const auto& e = (cmd); \ if (e != cudaSuccess) { \ std::stringstream ss; \ ss << static_cast(e) << " (" << cudaGetErrorName(e) << ", " << cudaGetErrorString(e) << ")"; \ throw DGException("CUDA runtime", __FILE__, __LINE__, ss.str()); \ } \ } while (0) #endif #ifndef DG_CUBLASLT_CHECK #if !DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE inline const char* cublasGetStatusString(cublasStatus_t status) { switch(status) { case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; default: return "Unknown cuBLAS error"; } } #endif #define DG_CUBLASLT_CHECK(cmd) \ do { \ const auto& e = (cmd); \ if (e != CUBLAS_STATUS_SUCCESS) { \ std::ostringstream ss; \ ss << static_cast(e) << " (" << cublasGetStatusString(e) << ")"; \ throw DGException("cuBLASLt", __FILE__, __LINE__, ss.str()); \ } \ } while (0) #endif } // namespace deep_gemm