File size: 4,433 Bytes
c1af2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#pragma once

#include <c10/core/Device.h>
#include <c10/util/Exception.h>

#include <c10/core/Stream.h>
#include <c10/util/Registry.h>

#include <c10/core/Allocator.h>

#include <c10/util/python_stub.h>
#include <ATen/detail/AcceleratorHooksInterface.h>

#include <string>
// Forward declaration of at::Context, so the type can be named here without
// pulling in its full definition.
namespace at {
class Context;
} // namespace at

namespace at {
/// Help text explaining that MTIA functionality was requested although the
/// MTIA extension is not included.
///
/// The message is assembled from adjacent string literals, so each fragment
/// must carry its own separating whitespace.
constexpr const char* MTIA_HELP =
    "The MTIA backend requires MTIA extension for PyTorch; "
    "this error has occurred because you are trying "
    "to use some MTIA's functionality without MTIA extension included.";

/// Default (fallback) hook set for the MTIA backend.
///
/// Each member holds the behavior used when nothing overrides it: a few
/// queries answer "not available" (`false` / `0`), `init()` is a no-op, and
/// every hook that needs a real backend fails through FAIL_MTIAHOOKS_FUNC.
/// Hooks are `const` and virtual so that an implementation can override them.
struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
// this fails the implementation if MTIAHooks functions are called, but
// MTIA backend is not present. `func` is spliced into the error message.
// NOTE: the expansion already ends in ';', so `FAIL_MTIAHOOKS_FUNC(x);` leaves
// a harmless empty statement behind. The macro is kept as-is because it is not
// #undef'd and may be used outside of this struct.
#define FAIL_MTIAHOOKS_FUNC(func) \
  TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");

  ~MTIAHooksInterface() override = default;

  /// Intentionally a no-op.
  void init() const override {
    // Avoid logging here, since MTIA needs init devices first then it will know
    // how many devices are available. Make it as no-op if mtia extension is not
    // dynamically loaded.
  }

  /// Reports whether MTIA is available; the default answer is `false`.
  virtual bool hasMTIA() const {
    return false;
  }

  /// Number of devices; the default answer is `0`.
  DeviceIndex deviceCount() const override {
    return 0;
  }

  /// Fails by default (see FAIL_MTIAHOOKS_FUNC).
  virtual void deviceSynchronize(c10::DeviceIndex /*device_index*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  /// Fails by default (see FAIL_MTIAHOOKS_FUNC).
  virtual std::string showConfig() const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    // Placeholder return, not expected to be reached; it matches the other
    // non-void hooks so that no control path falls off the end.
    return {};
  }

  /// The default answer is `false` for every device index.
  bool hasPrimaryContext(DeviceIndex /*device_index*/) const override {
    return false;
  }

  /// Fails by default (see FAIL_MTIAHOOKS_FUNC).
  void setCurrentDevice(DeviceIndex /*device*/) const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  /// Fails by default; the trailing `-1` is only a placeholder return.
  DeviceIndex getCurrentDevice() const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  /// Fails by default; the trailing `-1` is only a placeholder return.
  DeviceIndex exchangeDevice(DeviceIndex /*device*/) const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  /// Fails by default; the trailing `-1` is only a placeholder return.
  DeviceIndex maybeExchangeDevice(DeviceIndex /*device*/) const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  /// Fails by default; the unpacked stream is only a placeholder return.
  virtual c10::Stream getCurrentStream(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
  }

  /// Fails by default; the trailing `-1` is only a placeholder return.
  virtual int64_t getCurrentRawStream(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  /// Fails by default; the unpacked stream is only a placeholder return.
  virtual c10::Stream getDefaultStream(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
  }

  /// Fails by default (see FAIL_MTIAHOOKS_FUNC).
  virtual void setCurrentStream(const c10::Stream& /*stream*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  /// The default answer is `false` for every pointer.
  bool isPinnedPtr(const void* /*data*/) const override {
    return false;
  }

  /// Fails by default; `nullptr` is only a placeholder return.
  Allocator* getPinnedMemoryAllocator() const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  /// Fails by default; `nullptr` is only a placeholder return.
  virtual PyObject* memoryStats(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  /// Fails by default; `nullptr` is only a placeholder return.
  virtual PyObject* getDeviceCapability(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  /// Fails by default; `nullptr` is only a placeholder return.
  virtual PyObject* getDeviceProperties(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  /// Fails by default (see FAIL_MTIAHOOKS_FUNC).
  virtual void emptyCache() const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  /// Fails by default (see FAIL_MTIAHOOKS_FUNC).
  virtual void recordMemoryHistory(
      const std::optional<std::string>& /*enabled*/,
      const std::string& /*stacks*/,
      size_t /*max_entries*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  /// Fails by default; `nullptr` is only a placeholder return.
  virtual PyObject* memorySnapshot(
      const std::optional<std::string>& /*local_path*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  /// Fails by default, unlike deviceCount() above which returns `0`; the
  /// trailing `0` here is only a placeholder return.
  virtual DeviceIndex getDeviceCount() const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return 0;
  }

  /// Fails by default (see FAIL_MTIAHOOKS_FUNC).
  virtual void resetPeakMemoryStats(DeviceIndex /*device*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  /// Fails by default (see FAIL_MTIAHOOKS_FUNC).
  virtual void attachOutOfMemoryObserver(PyObject* /*observer*/) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }
};

// Empty type that is passed as the third argument of the MTIAHooksRegistry
// declaration below.
// NOTE(review): presumably the constructor-argument type for registered hook
// classes -- confirm against c10/util/Registry.h.
struct TORCH_API MTIAHooksArgs {};

// Declares MTIAHooksRegistry in terms of MTIAHooksInterface and MTIAHooksArgs.
TORCH_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
// Convenience wrapper around C10_REGISTER_CLASS for MTIAHooksRegistry; `clsname`
// is forwarded as both the second and the third macro argument.
#define REGISTER_MTIA_HOOKS(clsname) \
  C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname)

// Accessors declared here and defined elsewhere.
namespace detail {
// Returns a const reference to an MTIAHooksInterface.
// NOTE(review): presumably the registered MTIA hooks, or the default
// interface when none is registered -- confirm against the definition.
TORCH_API const MTIAHooksInterface& getMTIAHooks();
// Returns a bool.
// NOTE(review): by its name, whether MTIA hooks have been built/registered --
// confirm against the definition.
TORCH_API bool isMTIAHooksBuilt();
} // namespace detail
} // namespace at