| | #pragma once |
| |
|
| | #include <ATen/cuda/CUDAContext.h> |
#ifdef USE_ROCM
// ROCm builds only: collapse the hipSPARSE release into one integer code,
// major * 100000 + minor * 100 + patch (so 2.4.0 becomes 200400). This lets later
// feature checks compare releases with ordinary `#if` arithmetic.
// hipsparseVersionMajor/Minor/Patch are presumably provided by the header below.
#include <hipsparse/hipsparse-version.h>
#define HIPSPARSE_VERSION \
  ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
#endif
| |
|
// cuSPARSE Generic API switch.
// Expands to 1 only when CUDART_VERSION and CUSPARSE_VERSION are both defined and
// the cuSPARSE version code meets the platform minimum:
//   - non-Windows builds: CUSPARSE_VERSION >= 10300
//   - Windows (_WIN32) builds: CUSPARSE_VERSION >= 11000
// Otherwise (including builds without CUDA) it expands to 0.
// The previous condition `(V >= 10300) || (V >= 11000 && defined(_WIN32))` never
// enforced the Windows minimum, because V >= 11000 already implies V >= 10300.
// The platform clauses are therefore made mutually exclusive here.
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
    ((!defined(_WIN32) && (CUSPARSE_VERSION >= 10300)) || \
     (defined(_WIN32) && (CUSPARSE_VERSION >= 11000)))
#define AT_USE_CUSPARSE_GENERIC_API() 1
#else
#define AT_USE_CUSPARSE_GENERIC_API() 0
#endif
| |
|
// Expands to 1 for CUDA builds whose cuSPARSE version code is below 12000, and to
// 0 for cuSPARSE >= 12000 or when CUDART_VERSION / CUSPARSE_VERSION is undefined.
// NOTE(review): the "non-const descriptor" meaning is inferred from the macro name
// and the 12000 cut-off; confirm against the call sites.
#if !defined(CUDART_VERSION) || !defined(CUSPARSE_VERSION)
#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0
#elif CUSPARSE_VERSION < 12000
#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 1
#else
#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0
#endif
| |
|
// Expands to 1 for CUDA builds whose cuSPARSE version code is 12000 or newer.
// In CUDA builds it is the complement of AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS.
// It expands to 0 whenever CUDART_VERSION or CUSPARSE_VERSION is undefined.
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION)
#  if CUSPARSE_VERSION >= 12000
#    define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 1
#  else
#    define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0
#  endif
#else
#  define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0
#endif
| |
|
// hipSPARSE feature switches; all three expand to 0 outside ROCm builds.
// In ROCm builds the Generic API switch is always 1. The descriptor flavour depends
// on the hipSPARSE release:
//   HIPSPARSE_VERSION >= 200400 (2.4.0 under the major*100000 + minor*100 + patch
//   code defined earlier in this header) -> const descriptors;
//   older releases                       -> non-const descriptors.
#ifdef USE_ROCM
#  define AT_USE_HIPSPARSE_GENERIC_API() 1
#  if HIPSPARSE_VERSION >= 200400
#    define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1
#    define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
#  else
#    define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
#    define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 1
#  endif
#else  // !USE_ROCM
#  define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
#  define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
#  define AT_USE_HIPSPARSE_GENERIC_API() 0
#endif  // USE_ROCM
| |
|
// Gates the cuSPARSE Generic API SpSV path. It expands to 0 if the CUDA runtime or
// cuSPARSE version macro is missing, or if the cuSPARSE version code is below 11500.
// Otherwise it expands to 1.
#if !defined(CUDART_VERSION) || !defined(CUSPARSE_VERSION) || (CUSPARSE_VERSION < 11500)
#define AT_USE_CUSPARSE_GENERIC_SPSV() 0
#else
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1
#endif
| |
|
// Gates the cuSPARSE Generic API SpSM path. It expands to 1 only for CUDA builds
// whose cuSPARSE version code is at least 11600, and to 0 otherwise.
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION)
#  if CUSPARSE_VERSION >= 11600
#    define AT_USE_CUSPARSE_GENERIC_SPSM() 1
#  else
#    define AT_USE_CUSPARSE_GENERIC_SPSM() 0
#  endif
#else
#  define AT_USE_CUSPARSE_GENERIC_SPSM() 0
#endif
| |
|
// Gates the cuSPARSE Generic API SDDMM path. It requires both CUDART_VERSION and
// CUSPARSE_VERSION, with a cuSPARSE version code of 11400 or newer.
#if !defined(CUDART_VERSION) || !defined(CUSPARSE_VERSION)
#define AT_USE_CUSPARSE_GENERIC_SDDMM() 0
#elif CUSPARSE_VERSION >= 11400
#define AT_USE_CUSPARSE_GENERIC_SDDMM() 1
#else
#define AT_USE_CUSPARSE_GENERIC_SDDMM() 0
#endif
| |
|
// Triangular-solve switch. It expands to 0 only when neither the CUDA runtime
// (CUDART_VERSION) nor ROCm (USE_ROCM) is present; every GPU build gets 1.
// NOTE(review): despite the HIPSPARSE prefix, this is also 1 for CUDA builds.
#if !defined(CUDART_VERSION) && !defined(USE_ROCM)
#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0
#else
#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1
#endif
| |
|