|
|
#pragma once |
|
|
|
|
|
#include <ATen/cuda/CUDAContext.h> |
|
|
|
|
|
|
|
|
|
|
|
// AT_USE_CUSPARSE_GENERIC_API(): expands to 1 when building against a
// cuSPARSE whose CUSPARSE_VERSION meets the platform-specific minimum,
// otherwise 0. Judging by the name, it gates use of the cuSPARSE generic API.
//   * non-Windows:      CUSPARSE_VERSION >= 10300
//   * Windows (_WIN32): CUSPARSE_VERSION >= 11000
// (cusparse.h encodes CUSPARSE_VERSION as major*1000 + minor*100 + patch.)
//
// The two platform tests are deliberately mutually exclusive. An unguarded
// ">= 10300" clause OR-ed with a Windows-only ">= 11000" clause would make
// the Windows clause dead code, because ">= 11000" implies ">= 10300". The
// higher Windows floor would then never apply.
//
// Both branches define the macro, so consumers must test it with
// `#if AT_USE_CUSPARSE_GENERIC_API()`; `#ifdef` would always be true.
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
    ((!defined(_WIN32) && CUSPARSE_VERSION >= 10300) || \
     (defined(_WIN32) && CUSPARSE_VERSION >= 11000))
#define AT_USE_CUSPARSE_GENERIC_API() 1
#else
#define AT_USE_CUSPARSE_GENERIC_API() 0
#endif
|
|
|
|
|
|
|
|
// AT_USE_HIPSPARSE_GENERIC_52_API(): 0 unless this is a ROCm build
// (USE_ROCM defined) with ROCM_VERSION >= 50200, in which case it is 1.
// NOTE(review): 50200 presumably encodes ROCm 5.2.0; confirm against the
// build system that defines ROCM_VERSION.
// If ROCM_VERSION is undefined, the preprocessor treats it as 0, so the
// "disabled" branch is taken.
// Always defined (to 0 or 1): test with `#if ...()`, not `#ifdef`.
#if !defined(USE_ROCM) || ROCM_VERSION < 50200
#define AT_USE_HIPSPARSE_GENERIC_52_API() 0
#else
#define AT_USE_HIPSPARSE_GENERIC_52_API() 1
#endif
|
|
|
|
|
|
|
|
// AT_USE_HIPSPARSE_GENERIC_API(): 0 unless this is a ROCm build
// (USE_ROCM defined) with ROCM_VERSION >= 50100, in which case it is 1.
// NOTE(review): 50100 presumably encodes ROCm 5.1.0; confirm against the
// build system that defines ROCM_VERSION.
// If ROCM_VERSION is undefined, the preprocessor treats it as 0, so the
// "disabled" branch is taken.
// Always defined (to 0 or 1): test with `#if ...()`, not `#ifdef`.
#if !defined(USE_ROCM) || ROCM_VERSION < 50100
#define AT_USE_HIPSPARSE_GENERIC_API() 0
#else
#define AT_USE_HIPSPARSE_GENERIC_API() 1
#endif
|
|
|
|
|
|
|
|
// AT_USE_CUSPARSE_GENERIC_SPSV(): 1 only when both CUDART_VERSION and
// CUSPARSE_VERSION are defined and CUSPARSE_VERSION >= 11500
// (cuSPARSE 11.5.0 under cusparse.h's major*1000 + minor*100 + patch
// encoding); 0 in every other configuration. Judging by the name, it gates
// the cuSPARSE generic SpSV (sparse triangular solve, vector RHS) API.
// Always defined (to 0 or 1): test with `#if ...()`, not `#ifdef`.
#if !defined(CUDART_VERSION) || !defined(CUSPARSE_VERSION) || CUSPARSE_VERSION < 11500
#define AT_USE_CUSPARSE_GENERIC_SPSV() 0
#else
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1
#endif
|
|
|
|
|
|
|
|
// AT_USE_CUSPARSE_GENERIC_SPSM(): 1 only when both CUDART_VERSION and
// CUSPARSE_VERSION are defined and CUSPARSE_VERSION >= 11600
// (cuSPARSE 11.6.0 under cusparse.h's major*1000 + minor*100 + patch
// encoding); 0 in every other configuration. Judging by the name, it gates
// the cuSPARSE generic SpSM (sparse triangular solve, matrix RHS) API.
// Always defined (to 0 or 1): test with `#if ...()`, not `#ifdef`.
#if !defined(CUDART_VERSION) || !defined(CUSPARSE_VERSION) || CUSPARSE_VERSION < 11600
#define AT_USE_CUSPARSE_GENERIC_SPSM() 0
#else
#define AT_USE_CUSPARSE_GENERIC_SPSM() 1
#endif
|
|
|
|
|
|
|
|
// AT_USE_CUSPARSE_GENERIC_SDDMM(): 1 only when both CUDART_VERSION and
// CUSPARSE_VERSION are defined and CUSPARSE_VERSION >= 11400
// (cuSPARSE 11.4.0 under cusparse.h's major*1000 + minor*100 + patch
// encoding); 0 in every other configuration. Judging by the name, it gates
// the cuSPARSE generic SDDMM API.
// Always defined (to 0 or 1): test with `#if ...()`, not `#ifdef`.
#if !defined(CUDART_VERSION) || !defined(CUSPARSE_VERSION) || CUSPARSE_VERSION < 11400
#define AT_USE_CUSPARSE_GENERIC_SDDMM() 0
#else
#define AT_USE_CUSPARSE_GENERIC_SDDMM() 1
#endif
|
|
|
|
|
|
|
|
// AT_USE_HIPSPARSE_TRIANGULAR_SOLVE(): 1 on any CUDA build (CUDART_VERSION
// defined) and on ROCm builds (USE_ROCM defined) with ROCM_VERSION >= 40500;
// 0 otherwise. Despite the "HIPSPARSE" in the name, it is enabled
// unconditionally for CUDA.
// NOTE(review): 40500 presumably encodes ROCm 4.5.0; confirm against the
// build system that defines ROCM_VERSION.
//
// The condition is kept on a single line on purpose. The previous
// backslash-continued form was separated from its second half by blank
// lines. After line splicing, that left `#if defined(CUDART_VERSION) ||`,
// which is an incomplete expression and a preprocessor error, with the
// remaining operand stranded as a stray line.
//
// Always defined (to 0 or 1): test with `#if ...()`, not `#ifdef`.
#if defined(CUDART_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 40500)
#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1
#else
#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0
#endif
|
|
|