File size: 4,433 Bytes
c1af2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
#pragma once
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <c10/core/Stream.h>
#include <c10/util/Registry.h>
#include <c10/core/Allocator.h>
#include <c10/util/python_stub.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <string>
// Forward declaration of at::Context; nothing in this header uses its full definition.
namespace at { class Context; } // namespace at
namespace at {
constexpr const char* MTIA_HELP =
"The MTIA backend requires MTIA extension for PyTorch;"
"this error has occurred because you are trying "
"to use some MTIA's functionality without MTIA extension included.";
// Unconditionally fails a TORCH_CHECK with a message naming `func`. It is used
// by the default MTIAHooksInterface methods below, which are meant to be hit
// only when no MTIA backend is present.
// Notes:
//  - The expansion already ends in ';', so `FAIL_MTIAHOOKS_FUNC(x);` produces
//    an extra empty statement, which is harmless.
//  - Preprocessor macros are not scoped by namespaces or structs, and this
//    header does not #undef it, so it stays visible to every includer.
#define FAIL_MTIAHOOKS_FUNC(func) \
  TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");

/// Hook interface for the MTIA backend.
///
/// The default implementations here describe "no MTIA backend present":
///  - init() is a no-op;
///  - hasMTIA(), deviceCount(), hasPrimaryContext() and isPinnedPtr() report
///    false / 0;
///  - every other method fails through FAIL_MTIAHOOKS_FUNC, naming itself in
///    the message.
/// Value-returning stubs still end with a placeholder return so that every
/// path returns a value; those returns are not expected to be reached.
struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
  ~MTIAHooksInterface() override = default;

  // Deliberately a no-op, with no logging: MTIA has to initialize its devices
  // before it knows how many are available, so nothing is done here when the
  // MTIA extension has not been dynamically loaded.
  void init() const override {}

  virtual bool hasMTIA() const {
    return false;
  }

  DeviceIndex deviceCount() const override {
    return 0;
  }

  virtual void deviceSynchronize(c10::DeviceIndex /*device_index*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  virtual std::string showConfig() const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    // Placeholder: explicit return on every path, matching the other stubs.
    return {};
  }

  bool hasPrimaryContext(DeviceIndex /*device_index*/) const override {
    return false;
  }

  void setCurrentDevice(DeviceIndex /*device*/) const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  DeviceIndex getCurrentDevice() const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  DeviceIndex exchangeDevice(DeviceIndex /*device*/) const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  DeviceIndex maybeExchangeDevice(DeviceIndex /*device*/) const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  virtual c10::Stream getCurrentStream(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
  }

  virtual int64_t getCurrentRawStream(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  virtual c10::Stream getDefaultStream(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
  }

  virtual void setCurrentStream(const c10::Stream& /*stream*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  bool isPinnedPtr(const void* /*data*/) const override {
    return false;
  }

  Allocator* getPinnedMemoryAllocator() const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  virtual PyObject* memoryStats(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  virtual PyObject* getDeviceCapability(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  virtual PyObject* getDeviceProperties(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  virtual void emptyCache() const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  virtual void recordMemoryHistory(
      const std::optional<std::string>& /*enabled*/,
      const std::string& /*stacks*/,
      size_t /*max_entries*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  virtual PyObject* memorySnapshot(
      const std::optional<std::string>& /*local_path*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  // Unlike deviceCount() (which returns 0 by default), this MTIA-specific
  // query fails when no backend is present.
  virtual DeviceIndex getDeviceCount() const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return 0;
  }

  virtual void resetPeakMemoryStats(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  virtual void attachOutOfMemoryObserver(PyObject* /*observer*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }
};
// Argument type passed to registered MTIA hook implementations; it carries no
// data.
struct TORCH_API MTIAHooksArgs {};

// Declares MTIAHooksRegistry, a c10 registry (c10/util/Registry.h) whose
// entries are MTIAHooksInterface implementations built from MTIAHooksArgs.
TORCH_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);

// Registers `hooks_cls` in MTIAHooksRegistry, passing the class name both as
// the registration key and as the registered class to C10_REGISTER_CLASS.
#define REGISTER_MTIA_HOOKS(hooks_cls) \
  C10_REGISTER_CLASS(MTIAHooksRegistry, hooks_cls, hooks_cls)
namespace detail {

// Returns the MTIAHooksInterface instance in use. Defined outside this header;
// presumably it falls back to the default (failing) stubs when no MTIA backend
// is registered. TODO: confirm against the definition.
TORCH_API auto getMTIAHooks() -> const MTIAHooksInterface&;

// Reports whether MTIA hooks are built. Defined outside this header; see that
// definition for the exact criterion.
TORCH_API auto isMTIAHooksBuilt() -> bool;

} // namespace detail
} // namespace at
|