| | #pragma once |
| |
|
#include <c10/core/Allocator.h>
#include <ATen/core/Generator.h>
#include <c10/util/Exception.h>

#include <c10/util/Registry.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
| |
|
// Forward declaration only: the hooks in this header take Context strictly by
// pointer, so the complete type is not required here.
namespace at {
class Context;
} // namespace at
| |
|
| | |
| | namespace at { |
| |
|
| | |
| | |
| | |
| | |
| | struct TORCH_API HIPHooksInterface { |
| | |
| | |
| | virtual ~HIPHooksInterface() {} |
| |
|
| | |
| | virtual void initHIP() const { |
| | AT_ERROR("Cannot initialize HIP without ATen_hip library."); |
| | } |
| |
|
| | virtual std::unique_ptr<c10::GeneratorImpl> initHIPGenerator(Context*) const { |
| | AT_ERROR("Cannot initialize HIP generator without ATen_hip library."); |
| | } |
| |
|
| | virtual bool hasHIP() const { |
| | return false; |
| | } |
| |
|
| | virtual int64_t current_device() const { |
| | return -1; |
| | } |
| |
|
| | virtual Allocator* getPinnedMemoryAllocator() const { |
| | AT_ERROR("Pinned memory requires HIP."); |
| | } |
| |
|
| | virtual void registerHIPTypes(Context*) const { |
| | AT_ERROR("Cannot registerHIPTypes() without ATen_hip library."); |
| | } |
| |
|
| | virtual int getNumGPUs() const { |
| | return 0; |
| | } |
| | }; |
| |
|
| | |
| | |
// Argument type passed as the trailing parameter of the
// C10_DECLARE_REGISTRY(HIPHooksRegistry, ...) declaration that follows.
// It carries no data.
// NOTE(review): presumably it exists only because the registry macro requires
// at least one argument type — confirm against c10/util/Registry.h.
struct TORCH_API HIPHooksArgs {};
| |
|
| | C10_DECLARE_REGISTRY(HIPHooksRegistry, HIPHooksInterface, HIPHooksArgs); |
| | #define REGISTER_HIP_HOOKS(clsname) \ |
| | C10_REGISTER_CLASS(HIPHooksRegistry, clsname, clsname) |
| |
|
| | namespace detail { |
| | TORCH_API const HIPHooksInterface& getHIPHooks(); |
| |
|
| | } |
| | } |
| |
|