|
|
#pragma once
|
|
|
|
|
|
|
|
|
#include <cstdint>
|
|
|
#include <map>
|
|
|
|
|
|
#include <cuda_runtime_api.h>
|
|
|
#include <cusparse.h>
|
|
|
#include <cublas_v2.h>
|
|
|
|
|
|
|
|
|
|
|
|
#include <cublasLt.h>
|
|
|
|
|
|
#ifdef CUDART_VERSION
|
|
|
#include <cusolverDn.h>
|
|
|
#endif
|
|
|
|
|
|
#if defined(USE_CUDSS)
|
|
|
#include <cudss.h>
|
|
|
#endif
|
|
|
|
|
|
#if defined(USE_ROCM)
|
|
|
#include <hipsolver/hipsolver.h>
|
|
|
#endif
|
|
|
|
|
|
#include <c10/core/Allocator.h>
|
|
|
#include <c10/cuda/CUDAFunctions.h>
|
|
|
|
|
|
// Forward declaration of c10::Allocator so that declarations further down in
// this header can name the type through a pointer.
namespace c10 {
struct Allocator;
} // namespace c10
|
|
|
|
|
|
namespace at::cuda {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inline int64_t getNumGPUs() {
|
|
|
return c10::cuda::device_count();
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inline bool is_available() {
|
|
|
return c10::cuda::device_count() > 0;
|
|
|
}
|
|
|
|
|
|
// Declaration only; the definition lives outside this header. Takes no
// arguments and returns a raw pointer to a cudaDeviceProp.
// NOTE(review): presumably describes the device that is current for the
// caller, with the pointee owned by the library rather than the caller —
// confirm against the implementation before freeing or caching it.
TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties();
|
|
|
|
|
|
// Declaration only; the definition lives outside this header. Takes no
// arguments and returns an int.
// NOTE(review): presumably the warp size of the current device — confirm
// against the implementation.
TORCH_CUDA_CPP_API int warp_size();
|
|
|
|
|
|
// Declaration only; the definition lives outside this header. Takes a device
// index and returns a raw pointer to a cudaDeviceProp.
// NOTE(review): presumably the properties of the device identified by
// `device`; valid index range and pointee ownership are not visible here —
// confirm against the implementation.
TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(c10::DeviceIndex device);
|
|
|
|
|
|
// Declaration only; the definition lives outside this header. Takes two device
// indices and returns a bool.
// NOTE(review): presumably reports whether `device` can directly access
// memory that resides on `peer_device` — confirm against the implementation.
TORCH_CUDA_CPP_API bool
canDeviceAccessPeer(c10::DeviceIndex device, c10::DeviceIndex peer_device);
|
|
|
|
|
|
// Declaration only; the definition lives outside this header. Takes no
// arguments and returns a raw pointer to a c10::Allocator.
// NOTE(review): presumably the allocator used for CUDA device memory, owned
// by the library — confirm against the implementation.
TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
|
|
|
|
|
|
|
|
|
// Declaration only; the definition lives outside this header. Takes no
// arguments and returns a cusparseHandle_t.
// NOTE(review): which device/stream the handle is associated with and who
// owns it are not visible here — confirm against the implementation.
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
|
|
|
// Declaration only; the definition lives outside this header. Takes no
// arguments and returns a cublasHandle_t.
// NOTE(review): which device/stream the handle is associated with and who
// owns it are not visible here — confirm against the implementation.
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
|
|
// Declaration only; the definition lives outside this header. Takes no
// arguments and returns a cublasLtHandle_t.
// NOTE(review): the relationship between this handle and the one returned by
// getCurrentCUDABlasHandle() is not visible here — confirm against the
// implementation.
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
|
|
|
|
|
// Declaration only; the definition lives outside this header. Takes no
// arguments and returns nothing.
// NOTE(review): presumably releases the cuBLAS workspace buffers tracked by
// cublas_handle_stream_to_workspace(); synchronization requirements are not
// visible here — confirm against the implementation.
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
|
|
|
// Declaration only; the definition lives outside this header. Returns a
// mutable reference to a std::map whose keys are pairs of opaque void*
// values and whose mapped values are at::DataPtr objects.
// NOTE(review): the name suggests the key is (cuBLAS handle, stream) and the
// value is the workspace buffer for that pair — confirm against the
// implementation, including any locking expected of callers.
// NOTE(review): std::tuple is named here, but the only standard headers this
// file visibly includes directly are <cstdint> and <map>; consider including
// <tuple> explicitly rather than relying on a transitive include.
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
|
|
|
// Declaration only; the definition lives outside this header. Takes no
// arguments and returns a size_t.
// NOTE(review): presumably the cuBLAS workspace size in bytes; how the value
// is chosen is not visible here — confirm against the implementation.
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
|
|
|
|
|
|
#if defined(CUDART_VERSION) || defined(USE_ROCM)
|
|
|
// Declaration only; the definition lives outside this header. Takes no
// arguments and returns a cusolverDnHandle_t. It is visible only when the
// enclosing preprocessor guard holds (CUDART_VERSION or USE_ROCM defined).
// NOTE(review): handle ownership and device/stream association are not
// visible here — confirm against the implementation.
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
|
|
|
#endif
|
|
|
|
|
|
#if defined(USE_CUDSS)
|
|
|
// Declaration only; the definition lives outside this header. Takes no
// arguments and returns a cudssHandle_t. It is visible only when the
// enclosing preprocessor guard holds (USE_CUDSS defined).
// NOTE(review): handle ownership and device/stream association are not
// visible here — confirm against the implementation.
TORCH_CUDA_CPP_API cudssHandle_t getCurrentCudssHandle();
|
|
|
#endif
|
|
|
|
|
|
}
|
|
|
|