|
|
#pragma once |
|
|
|
|
|
#include <c10/util/Exception.h> |
|
|
#include <c10/util/Registry.h> |
|
|
|
|
|
// Help text appended to ORT-related error messages. It explains how to make
// the 'ort' device available to PyTorch.
// The leading space separates it from whatever message it is appended to.
constexpr const char* ORT_HELP =
    " You need to 'import torch_ort' to use the 'ort' device in PyTorch."
    " The 'torch_ort' module is provided by the ONNX Runtime itself"
    " (https://onnxruntime.ai).";
|
|
|
|
|
|
|
|
namespace at { |
|
|
|
|
|
struct TORCH_API ORTHooksInterface { |
|
|
|
|
|
|
|
|
virtual ~ORTHooksInterface() {} |
|
|
|
|
|
virtual std::string showConfig() const { |
|
|
TORCH_CHECK(false, "Cannot query detailed ORT version information.", ORT_HELP); |
|
|
} |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
struct TORCH_API ORTHooksArgs {}; |
|
|
|
|
|
C10_DECLARE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs); |
|
|
#define REGISTER_ORT_HOOKS(clsname) \ |
|
|
C10_REGISTER_CLASS(ORTHooksRegistry, clsname, clsname) |
|
|
|
|
|
namespace detail { |
|
|
TORCH_API const ORTHooksInterface& getORTHooks(); |
|
|
} |
|
|
|
|
|
} |
|
|
|