cranky-coder08 committed
Commit c1af2fa · verified · 1 parent: f4cade0

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +7 -0
  2. phivenv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-39.pyc +3 -0
  3. phivenv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-39.pyc +3 -0
  4. phivenv/Lib/site-packages/pkg_resources/__pycache__/__init__.cpython-39.pyc +3 -0
  5. phivenv/Lib/site-packages/pkg_resources/_vendor/__pycache__/pyparsing.cpython-39.pyc +3 -0
  6. phivenv/Lib/site-packages/regex/__pycache__/_regex_core.cpython-39.pyc +3 -0
  7. phivenv/Lib/site-packages/regex/__pycache__/test_regex.cpython-39.pyc +3 -0
  8. phivenv/Lib/site-packages/regex/_regex.cp39-win_amd64.pyd +3 -0
  9. phivenv/Lib/site-packages/torch/include/ATen/core/Formatting.h +25 -0
  10. phivenv/Lib/site-packages/torch/include/ATen/core/Generator.h +191 -0
  11. phivenv/Lib/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h +39 -0
  12. phivenv/Lib/site-packages/torch/include/ATen/core/IListRef.h +631 -0
  13. phivenv/Lib/site-packages/torch/include/ATen/core/IListRef_inl.h +203 -0
  14. phivenv/Lib/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h +111 -0
  15. phivenv/Lib/site-packages/torch/include/ATen/core/List.h +491 -0
  16. phivenv/Lib/site-packages/torch/include/ATen/core/List_inl.h +353 -0
  17. phivenv/Lib/site-packages/torch/include/ATen/core/MT19937RNGEngine.h +194 -0
  18. phivenv/Lib/site-packages/torch/include/ATen/core/NamedTensor.h +143 -0
  19. phivenv/Lib/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h +187 -0
  20. phivenv/Lib/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h +240 -0
  21. phivenv/Lib/site-packages/torch/include/ATen/core/PythonFallbackKernel.h +35 -0
  22. phivenv/Lib/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h +22 -0
  23. phivenv/Lib/site-packages/torch/include/ATen/core/QuantizerBase.h +84 -0
  24. phivenv/Lib/site-packages/torch/include/ATen/core/Range.h +25 -0
  25. phivenv/Lib/site-packages/torch/include/ATen/core/Reduction.h +14 -0
  26. phivenv/Lib/site-packages/torch/include/ATen/core/Scalar.h +1 -0
  27. phivenv/Lib/site-packages/torch/include/ATen/core/ScalarType.h +1 -0
  28. phivenv/Lib/site-packages/torch/include/ATen/core/Tensor.h +98 -0
  29. phivenv/Lib/site-packages/torch/include/ATen/core/TensorAccessor.h +275 -0
  30. phivenv/Lib/site-packages/torch/include/ATen/core/TensorBase.h +1056 -0
  31. phivenv/Lib/site-packages/torch/include/ATen/core/TensorBody.h +0 -0
  32. phivenv/Lib/site-packages/torch/include/ATen/core/TorchDispatchUtils.h +17 -0
  33. phivenv/Lib/site-packages/torch/include/ATen/core/TransformationHelper.h +175 -0
  34. phivenv/Lib/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h +1 -0
  35. phivenv/Lib/site-packages/torch/include/ATen/core/UnsafeFromTH.h +21 -0
  36. phivenv/Lib/site-packages/torch/include/ATen/core/VariableHooksInterface.h +83 -0
  37. phivenv/Lib/site-packages/torch/include/ATen/core/Variadic.h +92 -0
  38. phivenv/Lib/site-packages/torch/include/ATen/core/Vitals.h +94 -0
  39. phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h +213 -0
  40. phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h +106 -0
  41. phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction.h +283 -0
  42. phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h +320 -0
  43. phivenv/Lib/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h +27 -0
  44. phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h +38 -0
  45. phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h +41 -0
  46. phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/boxing.h +410 -0
  47. phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +785 -0
  48. phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h +140 -0
  49. phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/CppSignature.h +67 -0
  50. phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h +279 -0
.gitattributes CHANGED
@@ -54,3 +54,10 @@ phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=
  phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
  phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
  phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/pkg_resources/_vendor/__pycache__/pyparsing.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/pkg_resources/__pycache__/__init__.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/regex/_regex.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/regex/__pycache__/test_regex.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/regex/__pycache__/_regex_core.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
phivenv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-39.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a9ae9e3e39533b703fa0fb49576d02a073be55a6fbe3f9d9a38cbeb9ed03e116
+ size 100308
phivenv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-39.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:88b150085f0eb6dcd1c70d632b16ccf923b66e1700800a4756d06b3726b91fcf
+ size 132673
phivenv/Lib/site-packages/pkg_resources/__pycache__/__init__.cpython-39.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92b5449a62f76826fcde2e62b85c16f953a7ccfff8847bc4854a098b5a954dae
+ size 100411
phivenv/Lib/site-packages/pkg_resources/_vendor/__pycache__/pyparsing.cpython-39.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3f26434c485b7881d3ef563e57c88a171319a39cfcc3bf348cbe5bfd0d2a9887
+ size 201319
phivenv/Lib/site-packages/regex/__pycache__/_regex_core.cpython-39.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5727abd2cd4972398036f183a2e811e78ffa31946bf89c453917a171a61c12aa
+ size 114484
phivenv/Lib/site-packages/regex/__pycache__/test_regex.cpython-39.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa833453940a5409176fe65a5ba338e66d9d875a4a905f92c064b1ade0faba66
+ size 140105
phivenv/Lib/site-packages/regex/_regex.cp39-win_amd64.pyd ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:72ee579e80fb57b5b52f1a5a44b4dcbf85567e43442ad80f9da51f21e2f9977f
+ size 723968
phivenv/Lib/site-packages/torch/include/ATen/core/Formatting.h ADDED
@@ -0,0 +1,25 @@
+ #pragma once
+
+ #include <ostream>
+ #include <string>
+
+ #include <c10/core/Scalar.h>
+ #include <ATen/core/Tensor.h>
+
+ namespace c10 {
+ TORCH_API std::ostream& operator<<(std::ostream& out, Backend b);
+ TORCH_API std::ostream& operator<<(std::ostream & out, const Scalar& s);
+ TORCH_API std::string toString(const Scalar& s);
+ }
+ namespace at {
+
+ TORCH_API std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t);
+ TORCH_API std::ostream& print(
+ std::ostream& stream,
+ const Tensor& tensor,
+ int64_t linesize);
+ inline std::ostream& operator<<(std::ostream & out, const Tensor & t) {
+ return print(out,t,80);
+ }
+ TORCH_API void print(const Tensor & t, int64_t linesize=80);
+ }
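For orientation, the declarations above are ATen's tensor pretty-printing entry points. A minimal usage sketch (not part of this commit; assumes an ordinary libtorch build) might look like:

    #include <torch/torch.h>  // umbrella header; pulls in at::Tensor and the operators above
    #include <iostream>

    int main() {
      at::Tensor t = torch::rand({2, 3});
      std::cout << t << '\n';        // operator<< forwards to print(out, t, /*linesize=*/80)
      at::print(std::cout, t, 120);  // same printer, explicit line width
      return 0;
    }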
phivenv/Lib/site-packages/torch/include/ATen/core/Generator.h ADDED
@@ -0,0 +1,191 @@
+ #pragma once
+
+ #include <cstdint>
+ #include <deque>
+ #include <mutex>
+ #include <utility>
+
+ #include <c10/util/Exception.h>
+ #include <c10/util/intrusive_ptr.h>
+ #include <c10/core/Device.h>
+ #include <c10/core/DispatchKeySet.h>
+
+ // For the record I don't think this is a correct pimpl idiom.
+ // Including Impl header in interface header defeats the purpose
+ // because you can't change Impl private members without forcing
+ // everything that included the interface to rebuild.
+ // Impl should be forward-declared in the interface header instead.
+ #include <c10/core/GeneratorImpl.h>
+
+ /**
+ * Note [Generator]
+ * ~~~~~~~~~~~~~~~~
+ * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm to
+ * generate a seemingly random sequence of numbers, that may be later be used in creating
+ * a random distribution. Such an engine almost always maintains a state and requires a
+ * seed to start off the creation of random numbers. Often times, users have
+ * found it beneficial to be able to explicitly create, retain, and destroy
+ * PRNG states and also be able to have control over the seed value.
+ *
+ * A Generator in ATen gives users the ability to read, write and modify a PRNG engine.
+ * For instance, it does so by letting users seed a PRNG engine, fork the state of the
+ * engine, etc.
+ *
+ * By default, there is one generator per device, and a device's generator is
+ * lazily created. A user can use the torch.Generator() api to create their own generator.
+ */
+
+ /**
+ * Note [Acquire lock when using random generators]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * Generator and its derived classes are NOT thread-safe. Please note that most of the
+ * places where we have inserted locking for generators are historically based, and we
+ * haven't actually checked that everything is truly thread safe (and it probably isn't).
+ * Please use the public mutex_ when using any methods from these classes, except for the
+ * read-only methods. You can learn about the usage by looking into the unittests
+ * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard.
+ *
+ * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making
+ * them non-thread safe and instead making the generator state splittable, to accommodate
+ * forks into other threads).
+ */
+
+ namespace at {
+
+ class Tensor;
+
+ struct TORCH_API Generator {
+ Generator() = default;
+
+ explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
+ : impl_(std::move(gen_impl)) {
+ if (impl_.get() == nullptr) {
+ throw std::runtime_error("GeneratorImpl with nullptr is not supported");
+ }
+ }
+
+ bool operator==(const Generator& rhs) const {
+ return this->impl_ == rhs.impl_;
+ }
+
+ bool operator!=(const Generator& rhs) const {
+ return !((*this) == rhs);
+ }
+
+ bool defined() const {
+ return static_cast<bool>(impl_);
+ }
+
+ c10::GeneratorImpl* unsafeGetGeneratorImpl() const {
+ return impl_.get();
+ }
+
+ c10::GeneratorImpl* unsafeReleaseGeneratorImpl() {
+ return impl_.release();
+ }
+
+ const c10::intrusive_ptr<c10::GeneratorImpl>& getIntrusivePtr() const {
+ return impl_;
+ }
+
+ void set_current_seed(uint64_t seed) { impl_->set_current_seed(seed); }
+ // Sets the offset of Generator state to the desired offset. This is currently
+ // supported for only Philox based Generators, i.e., CUDA and MPS.
+ void set_offset(uint64_t offset) { impl_->set_offset(offset); }
+
+ // Returns the offset of Generator state. This is currently supported for only
+ // Philox based Generators, i.e., CUDA and MPS.
+ uint64_t get_offset() const { return impl_->get_offset(); }
+
+ uint64_t current_seed() const { return impl_->current_seed(); }
+
+ uint64_t seed() { return impl_->seed(); }
+
+ // Implementation not inlined to prevent cycle reference between
+ // `ATen/core/Generator.h` and `ATen/core/Tensor.h`
+ void set_state(const at::Tensor& new_state);
+
+ at::Tensor get_state() const;
+
+ void graphsafe_set_state(const Generator& new_state);
+
+ Generator graphsafe_get_state() const;
+
+ std::mutex& mutex() {
+ return impl_->mutex_;
+ }
+
+ DispatchKeySet key_set() const {
+ return impl_->key_set();
+ }
+
+ Device device() const { return impl_->device(); }
+
+ inline void set_pyobj(PyObject* pyobj) const noexcept {
+ impl_->set_pyobj(pyobj);
+ }
+
+ inline PyObject* pyobj() const noexcept {
+ return impl_->pyobj();
+ }
+
+ template<typename T>
+ T* get() const { return static_cast<T*>(impl_.get()); }
+
+ Generator clone() const {
+ return Generator(impl_->clone());
+ }
+
+ private:
+ c10::intrusive_ptr<c10::GeneratorImpl> impl_;
+ };
+
+ template<class Impl, class... Args>
+ Generator make_generator(Args&&... args) {
+ return Generator(c10::make_intrusive<Impl>(std::forward<Args>(args)...));
+ }
+
+ /**
+ * Utility function to static cast input Generator* to
+ * the backend generator type (CPU/CUDAGeneratorImpl etc.)
+ */
+ template <typename T>
+ inline T * check_generator(std::optional<Generator> gen) {
+ TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");
+ TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");
+ TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");
+ return gen->get<T>();
+ }
+
+ /**
+ * Utility function used in tensor implementations, which
+ * supplies the default generator to tensors, if an input generator
+ * is not supplied. The input Generator* is also static casted to
+ * the backend generator type (CPU/CUDAGeneratorImpl etc.)
+ */
+ template <typename T>
+ inline T* get_generator_or_default(const std::optional<Generator>& gen, const Generator& default_gen) {
+ return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
+ }
+
+ namespace detail {
+
+ /**
+ * Helper function for checking the validity of new random generator
+ * state. Right now following conditions are checked:
+ *
+ * - The new state tensor must be a torch.ByteTensor
+ * - Data of the new state tensor must be contiguous
+ */
+ inline void check_rng_state(const c10::TensorImpl& new_state) {
+ TORCH_CHECK_TYPE(
+ new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte,
+ "RNG state must be a torch.ByteTensor"
+ );
+
+ TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous");
+ }
+
+ } // namespace detail
+
+ } // namespace at
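Note [Acquire lock when using random generators] above asks callers to hold the generator's public mutex around any non-read-only call. A minimal sketch of that discipline (not part of this commit; at::CPUGeneratorImpl and its header location are assumptions based on current ATen):

    #include <ATen/core/Generator.h>
    #include <ATen/CPUGeneratorImpl.h>  // assumed location of at::CPUGeneratorImpl
    #include <mutex>

    void seed_example() {
      // make_generator<Impl>(...) wraps a concrete GeneratorImpl in the type-erased handle above.
      at::Generator gen = at::make_generator<at::CPUGeneratorImpl>(/*seed_in=*/42);

      // Generators are not thread-safe; hold the public mutex for mutating calls.
      std::lock_guard<std::mutex> lock(gen.mutex());
      gen.set_current_seed(1337);
    }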
phivenv/Lib/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h ADDED
@@ -0,0 +1,39 @@
+ #pragma once
+
+ #include <ATen/core/Generator.h>
+ #include <c10/util/intrusive_ptr.h>
+
+ namespace at {
+
+ using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>;
+
+ TORCH_API std::optional<GeneratorFuncType>& GetGeneratorPrivate();
+
+ class TORCH_API _GeneratorRegister {
+ public:
+ explicit _GeneratorRegister(const GeneratorFuncType& func);
+ };
+
+ TORCH_API at::Generator GetGeneratorForPrivateuse1(
+ c10::DeviceIndex device_index);
+
+ /**
+ * This is used to register Generator to PyTorch for `privateuse1` key.
+ *
+ * Usage: REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1)
+ *
+ * class CustomGeneratorImpl : public c10::GeneratorImpl {
+ * CustomGeneratorImpl(DeviceIndex device_index = -1);
+ * explicit ~CustomGeneratorImpl() override = default;
+ * ...
+ * };
+ *
+ * at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) {
+ * return at::make_generator<CustomGeneratorImpl>(id);
+ * }
+ */
+
+ #define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \
+ static auto temp##GeneratorPrivate = at::_GeneratorRegister(GeneratorPrivate);
+
+ } // namespace at
phivenv/Lib/site-packages/torch/include/ATen/core/IListRef.h ADDED
@@ -0,0 +1,631 @@
1
+ #pragma once
2
+
3
+ #include <ATen/core/ivalue_to.h>
4
+ #include <c10/util/ArrayRef.h>
5
+ #include <c10/util/Exception.h>
6
+
7
+ #include <functional>
8
+ #include <initializer_list>
9
+ #include <iterator>
10
+ #include <type_traits>
11
+
12
+ /*
13
+ * [Note: IListRef]
14
+ * Wrapper around different API containers (e.g. boxed and unboxed).
15
+ *
16
+ * What is it?
17
+ * ===========
18
+ * It is a tagged union of both boxed and unboxed API containers.
19
+ * Working implementations:
20
+ *
21
+ * - `IListRef<at::Tensor>`
22
+ * - `IListRef<at::OptionalTensorRef>`
23
+ *
24
+ * Note that `IListRef` is a view type. Meaning that it won't own the
25
+ * tensors it holds. It's intended to be used only as argument parameters.
26
+ * Specifically, where these 2 worlds overlap.
27
+ *
28
+ * What is this for?
29
+ * =================
30
+ * Historically, PyTorch has maintained 2 different APIs: the unboxed
31
+ * (called from C++ API and Python eager mode) and boxed APIs (called
32
+ * from the TorchScript JIT, mobile interpreter, and boxed fallbacks).
33
+ *
34
+ * Calling unboxed kernels from the boxed "world" and vice-versa may
35
+ * result in non-negligible overhead. Lists are one of those types:
36
+ *
37
+ * - Boxed world: `c10::List`
38
+ * - Unboxed world: `c10::ArrayRef`
39
+ *
40
+ * In this context, `c10::IListRef` solves this problem by wrapping those
41
+ * 2 container types, so that we don't need to convert from one to
42
+ * the other.
43
+ *
44
+ * (see https://github.com/pytorch/pytorch/issues/66328)
45
+ *
46
+ * What does it do?
47
+ * ================
48
+ * This container wraps around the different tagged containers
49
+ * (currently, only boxed and unboxed), without incurring in extra
50
+ * overhead for converting from one to another. It does so while
51
+ * exposing usual container methods, which dispatch to corresponding
52
+ * implementations.
53
+ *
54
+ * While it works with different container types, it introduces
55
+ * overhead for repeatedly calling member functions (since those will
56
+ * get dispatched, again). Therefore, you should only use it to iterate
57
+ * through the list up to one time. If you need to do more complex things,
58
+ * call `materialize()` first.
59
+ *
60
+ * Adding support for a new Tag
61
+ * ============================
62
+ * Suppose we want to add a new tag: `Chest`. Here are the steps
63
+ * we would have to go through:
64
+ *
65
+ * 1. Add a line for it in the macro `TORCH_ILISTREF_FORALL_TAGS`.
66
+ *
67
+ * #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \
68
+ * ...
69
+ * _(Chest, ##__VA_ARGS__)
70
+ *
71
+ * 2. Add type aliases, union members, and constructors.
72
+ *
73
+ * template <typename T>
74
+ * class IListRef {
75
+ * ...
76
+ * using chest_type =
77
+ * typename detail::IListRefTagImpl<T, IListRefTag::Chest>::list_type;
78
+ * ...
79
+ * IListRef(...) : tag_(IListRefTag::Chest) {
80
+ * ...
81
+ * }
82
+ * ...
83
+ * union Payload {
84
+ * ...
85
+ * chest_type chest;
86
+ * ...
87
+ * };
88
+ * ...
89
+ * };
90
+ *
91
+ * 3. Add a default implementation for it (in 'IListRef_inl.h'). It's
92
+ * preferable to make the default implementation work for `T = Tensor`
93
+ * (both `Unboxed` and `Boxed` do it).
94
+ *
95
+ * template <typename T, typename ListElemT>
96
+ * class IListRefTagImplBase<IListRefTag::Chest, T, ListElemT> {
97
+ * public:
98
+ * using elem_type = ListElemT;
99
+ * using list_type = ChestContainer<elem_type>;
100
+ *
101
+ * static const list_type& unwrap(const IListRef<T>& ilist) { ... }
102
+ *
103
+ * static typename list_type::const_iterator& unwrap(
104
+ * IListRefIterator<T>& it) { ... }
105
+ *
106
+ * static const typename list_type::const_iterator& unwrap(
107
+ * const IListRefIterator<T>& it) { ... }
108
+ *
109
+ * static IListRefConstRef<T> iterator_get(
110
+ * const typename list_type::const_iterator& it) { ... }
111
+ * }
112
+ *
113
+ * 4. Add an specialization for each of the already supported types.
114
+ * Finally, for consistency, add them to the tracking list.
115
+ * (see [Note: IListRefTagImpl Specializations])
116
+ *
117
+ * template <>
118
+ * class IListRefTagImpl<IListRefTag::Chest, at::Tensor>
119
+ * : public IListRefTagImplBase<IListRefTag::Chest, at::Tensor> {};
120
+ *
121
+ * Adding support for a new Type
122
+ * =============================
123
+ * Suppose we want to add support for a new type: `Matrix`.
124
+ * Here are the steps we would have to go through:
125
+ *
126
+ * 1. Add an specialization for each of the existing tags.
127
+ * For consistency, add them to the tracking list.
128
+ * (see [Note: IListRefTagImpl Specializations])
129
+ *
130
+ * template <>
131
+ * class IListRefTagImpl<IListRefTag::Unboxed, Matrix>
132
+ * : public IListRefTagImplBase<IListRefTag::Unboxed, Matrix> {};
133
+ *
134
+ * template <>
135
+ * class IListRefTagImpl<Matrix, IListRefTag::Boxed>
136
+ * : public IListRefTagImplBase<IListRefTag::Boxed, Matrix> {};
137
+ *
138
+ * Common Problems
139
+ * ===============
140
+ * 1. One of `IListRef(Iterator)` methods are failing to compile.
141
+ *
142
+ * That may be happening because the container type you added
143
+ * is not compatible with the code written for that method. If
144
+ * that's true, then you might have to transform that code into
145
+ * a static method call (see `List::operator[]` method).
146
+ *
147
+ * 2. Can't make `IListRefIterator<T>::operator*` return a const-reference.
148
+ *
149
+ * First, keep in mind that we assume that boxed containers will
150
+ * have to deal with `IValue` (e.g. `c10::List`). In this context,
151
+ * what may be happening is that `IValue` doesn't store internally
152
+ * your type `T`. Instead, it constructs a type new `T` everytime
153
+ * you try to get `T` for it (see `IListRef<at::OptinalTensorRef>`).
154
+ */
155
+
156
+ namespace c10 {
157
+ template <typename T>
158
+ class IListRef;
159
+
160
+ /*
161
+ * Applies arbitrary macros to each `IListRefTag`.
162
+ */
163
+ #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \
164
+ _(Unboxed, ##__VA_ARGS__) \
165
+ _(Boxed, ##__VA_ARGS__) \
166
+ _(Materialized, ##__VA_ARGS__)
167
+
168
+ /*
169
+ * Defines a "switch-case" for `TAG`. Inside, it executes `BODY`,
170
+ * while bringing to scope:
171
+ *
172
+ * - `ImplT`: the implementation class for `TAG`
173
+ * - `this_`: the result of unwrapping `this`
174
+ */
175
+ #define TORCH_ILISTREF_UNWRAP_CASE(TAG, BODY) \
176
+ case c10::IListRefTag::TAG: { \
177
+ using ImplT = c10::detail::IListRefTagImpl<IListRefTag::TAG, T>; \
178
+ auto& this_ = ImplT::unwrap(*this); \
179
+ BODY \
180
+ } break;
181
+
182
+ /*
183
+ * Dispatches the unwrap call, depending on `TAG`, followed by
184
+ * the execution of `BODY`. It aborts if `TAG` is not a `IListRefTag`.
185
+ *
186
+ * This macro is useful because it allows us to handle different
187
+ * types (that correspond to different tags) to be implemented
188
+ * only once. We can do it even when the implementation of the
189
+ * different tags aren't syntatically the same, by dispatching
190
+ * it to a function (e.g. `ImplT::<dispatch-function>(this_)`).
191
+ */
192
+ #define TORCH_ILISTREF_UNWRAP(TAG, BODY) \
193
+ switch (TAG) { \
194
+ TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \
195
+ break; \
196
+ default: \
197
+ TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); \
198
+ }
199
+
200
+ enum class IListRefTag {
201
+ #define DEFINE_TAG(tag, ...) tag,
202
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_TAG)
203
+ #undef DEFINE_TAG
204
+ None
205
+ };
206
+
207
+ namespace detail {
208
+ /*
209
+ * Type alias that specifies whether we return a reference or a copy of `T`.
210
+ *
211
+ * What is this for?
212
+ * =================
213
+ * Since values in the boxed world are represented by an `IValue`, we also
214
+ * depend on whether it can be converted to a const-reference (`Tensor`) or
215
+ * has to create a new copy of `T` (`OptionalTensorRef`).
216
+ */
217
+ template <typename T>
218
+ using IListRefConstRef = typename ivalue_to_const_ref_overload_return<T>::type;
219
+
220
+ /*
221
+ * Interface that implements key functions for each `IListRefTag` type.
222
+ *
223
+ * What is this for?
224
+ * =================
225
+ * Given an `IListRef(Iterator)<T>`, some methods have to be implemented
226
+ * differently for each `TAG`. Therefore, the methods inside this class
227
+ * are used as dispatch targets for the different `IListRefTag` values.
228
+ *
229
+ * You should create an specialization of this class for each possible
230
+ * combination of `IListRefTag` type (except `None`) and element types
231
+ * (e.g. `Tensor`).
232
+ *
233
+ * What does it do?
234
+ * ================
235
+ * 1. defines static methods to be used as dispatch targets by both
236
+ * `IListRef<T>` and `IListRefIterator<T>` (see the implementation of
237
+ * `IListRefTagImplBase`).
238
+ *
239
+ * 2. defines the `elem_type` and `list_type` aliases that will be
240
+ * used in the definition of `IListRef<T>`. In general, we should do
241
+ * so by inheriting from `IListRefTagImplBase<TAG, T, ListElemT>`.
242
+ *
243
+ * [Note: IListRefTagImpl Specialization]
244
+ * ======================================
245
+ * For `IListRef(Iterator)<at::Tensor>`:
246
+ * - <IListRefTag::Unboxed, at::Tensor>
247
+ * - <IListRefTag::Boxed, at::Tensor>
248
+ * - <IListRefTag::Materialized, at::Tensor>
249
+ *
250
+ * For `IListRef(Iterator)<at::OptionalTensorRef>`:
251
+ * - <IListRefTag::Unboxed, at::OptionalTensorRef>
252
+ * - <IListRefTag::Boxed, at::OptionalTensorRef>
253
+ * - <IListRefTag::Materialized, at::OptionalTensorRef>
254
+ */
255
+ template <IListRefTag TAG, typename T>
256
+ class IListRefTagImpl {};
257
+
258
+ /*
259
+ * Base implementation of `IListRefTagImpl<TAG, T>` methods.
260
+ *
261
+ * What is this for?
262
+ * =================
263
+ * This should make adding specializations for new types easier. For
264
+ * example, one should be able to add a new type just by making its
265
+ * `IListRefTagImpl` specialization inherit from `IListRefTagImplBase`.
266
+ *
267
+ * You should create a partial specialization for this class only if
268
+ * you introduce a new `IListRefTag`. The idea being that there is one
269
+ * default implementation for each possible value of `IListRefTag`.
270
+ *
271
+ * What does it do?
272
+ * ================
273
+ * 1. defines `elem_type` as an alias to `ListElemT`.
274
+ *
275
+ * 1. defines `list_type` as an alias to the default container type
276
+ * that will hold a collection of `elem_type`. The idea being that
277
+ * all types tagged as `TAG` will have `list_type` as its container,
278
+ * with different `elem_type`.
279
+ *
280
+ * 3. defines the default implementation for each of the methods that
281
+ * are supposed to be defined on `IListRefTagImpl` specializations.
282
+ *
283
+ * 4. inheriting from `IListRefTagImplBase<TAG, T, ListElemT>` also means
284
+ * that the payload of the type `IListRef<T>` will be of type `list_type`
285
+ * when it is tagged as `TAG`.
286
+ */
287
+ template <IListRefTag TAG, typename T, typename ListElemT = T>
288
+ class IListRefTagImplBase {};
289
+
290
+ /*
291
+ * Materialized container for `IListRef<T>`.
292
+ *
293
+ * What is this for?
294
+ * =================
295
+ * Container that groups `T` references together. This exchanges the
296
+ * overhead of every method call from `IListRef<T>` for a dynamic allocation.
297
+ *
298
+ * You should use this container instead of `IListRef<T>` if:
299
+ *
300
+ * - You are going to iterate the list more than once
301
+ * - You need to repeatedly access arbitrary elements (using `operator[]`)
302
+ * What does it do?
303
+
304
+ * ================
305
+ * Removes the reference (&) from the type, and wraps it into a
306
+ * `std::reference_wrapper`. If `IListRefConstRef<T>` is not a
307
+ * reference type, then it's left unchanged.
308
+ */
309
+ template <typename T>
310
+ using _MaterializedIListRefElem = std::conditional_t<
311
+ std::is_reference_v<T>,
312
+ typename std::reference_wrapper<std::remove_reference_t<T>>,
313
+ T>;
314
+
315
+ template <typename T>
316
+ using MaterializedIListRefElem = _MaterializedIListRefElem<IListRefConstRef<T>>;
317
+
318
+ template <typename T>
319
+ using MaterializedIListRef = std::vector<MaterializedIListRefElem<T>>;
320
+
321
+ } // namespace detail
322
+
323
+ /*
324
+ * Iterator for `IListRef<T>`.
325
+ *
326
+ * What is it?
327
+ * ===========
328
+ * Currently, a `std::bidirectional_iterator` that wraps the iterator
329
+ * types defined for each of the `IListRefTag`.
330
+ *
331
+ * One should be able to use it, as if it were the unwrapped
332
+ * iterators themselves.
333
+
334
+ * What does it do?
335
+ * ================
336
+ * Similarly to `IListRef<T>`, this is a wrapper class. Specifically, it
337
+ * wraps each container's `const_iterator` type alias. So, for example,
338
+ * given that the container for `IListRefTag::Boxed` is `c10::List`, this
339
+ * iterator will wrap a `c10::List::const_iterator`.
340
+ *
341
+ * [Note: MSVC Iterator Debug]
342
+ * ===========================
343
+ * MSVC `vector<T>::iterator` implementation (used in the boxed variant)
344
+ * makes it so this union's destructor, copy-constructor (assignment), and
345
+ * move-constructor (assignment) are implicitly deleted.
346
+ *
347
+ * Therefore, we need to explicitly define them as needed. Follows a list
348
+ * of places where these are needed and their reason:
349
+ *
350
+ * - `Payload` destructor:
351
+ * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is set to 2.
352
+ *
353
+ * - `IListRefIterator` destructor:
354
+ * same as above. However, we need to explicitly call the variant
355
+ * destructor explicitly.
356
+ *
357
+ * - `IListRefIterator` copy-constructor:
358
+ * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is different
359
+ * than 0.
360
+ */
361
+ template <typename T>
362
+ class IListRefIterator {
363
+ private:
364
+ #define DEFINE_FRIEND_CLASS(TAG, ...) \
365
+ friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \
366
+ friend class detail::IListRefTagImplBase< \
367
+ IListRefTag::TAG, \
368
+ T, \
369
+ typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>;
370
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
371
+ #undef DEFINE_FRIEND_CLASS
372
+
373
+ public:
374
+ // C++17 friendly std::iterator implementation
375
+ using iterator_category = std::bidirectional_iterator_tag;
376
+ using value_type = T;
377
+ using difference_type = std::ptrdiff_t;
378
+ using pointer = T*;
379
+ using reference = T&;
380
+
381
+ using unboxed_iterator_type = typename detail::
382
+ IListRefTagImpl<IListRefTag::Unboxed, T>::list_type::const_iterator;
383
+ using boxed_iterator_type = typename detail::
384
+ IListRefTagImpl<IListRefTag::Boxed, T>::list_type::const_iterator;
385
+ using materialized_iterator_type =
386
+ typename detail::MaterializedIListRef<T>::const_iterator;
387
+
388
+ IListRefIterator() : tag_(IListRefTag::None) {}
389
+
390
+ #if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL != 0
391
+ // See [Note: MSVC Iterator Debug]
392
+ IListRefIterator(const IListRefIterator& iterator)
393
+ : tag_(iterator.tag_) {
394
+ switch (tag_) {
395
+ case IListRefTag::Boxed:
396
+ payload_.boxed_iterator = iterator.payload_.boxed_iterator;
397
+ break;
398
+ case IListRefTag::Unboxed:
399
+ payload_.unboxed_iterator = iterator.payload_.unboxed_iterator;
400
+ break;
401
+ case IListRefTag::Materialized:
402
+ payload_.materialized_iterator = iterator.payload_.materialized_iterator;
403
+ break;
404
+ default:
405
+ TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
406
+ }
407
+ }
408
+ #endif
409
+
410
+ #if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL == 2
411
+ // See [Note: MSVC Iterator Debug]
412
+ ~IListRefIterator() noexcept(false) {
413
+ switch (tag_) {
414
+ case IListRefTag::Boxed:
415
+ payload_.boxed_iterator.~boxed_iterator_type();
416
+ break;
417
+ case IListRefTag::Unboxed:
418
+ payload_.unboxed_iterator.~unboxed_iterator_type();
419
+ break;
420
+ case IListRefTag::Materialized:
421
+ payload_.materialized_iterator.~materialized_iterator_type();
422
+ break;
423
+ default:
424
+ TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
425
+ }
426
+ }
427
+ #endif
428
+
429
+ IListRefIterator(boxed_iterator_type boxed) : tag_(IListRefTag::Boxed) {
430
+ payload_.boxed_iterator = boxed;
431
+ }
432
+
433
+ IListRefIterator(unboxed_iterator_type unboxed) : tag_(IListRefTag::Unboxed) {
434
+ payload_.unboxed_iterator = unboxed;
435
+ }
436
+
437
+ IListRefIterator(materialized_iterator_type materialized) : tag_(IListRefTag::Materialized) {
438
+ payload_.materialized_iterator = materialized;
439
+ }
440
+
441
+ detail::IListRefConstRef<T> operator*() const {
442
+ TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::iterator_get(this_); });
443
+ }
444
+
445
+ IListRefIterator& operator++() {
446
+ TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
447
+ return *this;
448
+ }
449
+
450
+ IListRefIterator operator++(int) {
451
+ auto old = *this;
452
+ TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
453
+ return old;
454
+ }
455
+
456
+ IListRefIterator& operator--() {
457
+ TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
458
+ return *this;
459
+ }
460
+
461
+ IListRefIterator operator--(int) {
462
+ auto old = *this;
463
+ TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
464
+ return old;
465
+ }
466
+
467
+ bool operator==(const IListRefIterator& rhs) const {
468
+ if (tag_ != rhs.tag_) {
469
+ return false;
470
+ }
471
+ TORCH_ILISTREF_UNWRAP(tag_, {
472
+ auto& rhs_it = ImplT::unwrap(rhs);
473
+ return this_ == rhs_it;
474
+ });
475
+ }
476
+
477
+ bool operator!=(const IListRefIterator& rhs) const {
478
+ return !(*this == rhs);
479
+ }
480
+
481
+ private:
482
+ union Payload {
483
+ boxed_iterator_type boxed_iterator;
484
+ unboxed_iterator_type unboxed_iterator;
485
+ materialized_iterator_type materialized_iterator;
486
+ void* _init_ptr;
487
+ Payload() : _init_ptr(nullptr) {}
488
+ #if defined(_MSC_VER)
489
+ // See [Note: MSVC Iterator Debug]
490
+ ~Payload() {}
491
+ #endif
492
+ };
493
+
494
+ Payload payload_;
495
+ IListRefTag tag_;
496
+ };
497
+
498
+ /*
499
+ * See [Note: IListRef]
500
+ */
501
+ template <typename T>
502
+ class IListRef {
503
+ private:
504
+ #define DEFINE_FRIEND_CLASS(TAG, ...) \
505
+ friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \
506
+ friend class detail::IListRefTagImplBase< \
507
+ IListRefTag::TAG, \
508
+ T, \
509
+ typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>;
510
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
511
+ #undef DEFINE_FRIEND_CLASS
512
+
513
+ public:
514
+ using unboxed_type =
515
+ typename detail::IListRefTagImpl<IListRefTag::Unboxed, T>::list_type;
516
+ using boxed_type =
517
+ typename detail::IListRefTagImpl<IListRefTag::Boxed, T>::list_type;
518
+ using materialized_type =
519
+ typename detail::MaterializedIListRef<T>;
520
+
521
+ using iterator = IListRefIterator<T>;
522
+ using const_iterator = IListRefIterator<T>;
523
+ using reverse_iterator = std::reverse_iterator<iterator>;
524
+ using value_type = typename iterator::value_type;
525
+
526
+ IListRef() : tag_(IListRefTag::None) {}
527
+
528
+ IListRef(const boxed_type& boxed) : tag_(IListRefTag::Boxed) {
529
+ payload_.boxed = &boxed;
530
+ }
531
+
532
+ IListRef(const unboxed_type& unboxed) : tag_(IListRefTag::Unboxed) {
533
+ payload_.unboxed = unboxed;
534
+ }
535
+
536
+ IListRef(const std::initializer_list<T>& list) : tag_(IListRefTag::Unboxed) {
537
+ payload_.unboxed = at::ArrayRef<T>(list);
538
+ }
539
+
540
+ template <
541
+ typename... UnboxedConstructorArgs,
542
+ typename = std::enable_if_t<
543
+ std::is_constructible_v<unboxed_type, UnboxedConstructorArgs...>>>
544
+ IListRef(UnboxedConstructorArgs&&... args) : tag_(IListRefTag::Unboxed) {
545
+ payload_.unboxed = unboxed_type(std::forward<UnboxedConstructorArgs>(args)...);
546
+ }
547
+
548
+ IListRef(const materialized_type& materialized) : tag_(IListRefTag::Materialized) {
549
+ payload_.materialized = &materialized;
550
+ }
551
+
552
+ size_t size() const {
553
+ TORCH_ILISTREF_UNWRAP(tag_, { return this_.size(); });
554
+ }
555
+
556
+ bool empty() const {
557
+ return size() == 0;
558
+ }
559
+
560
+ iterator begin() const {
561
+ TORCH_ILISTREF_UNWRAP(tag_, { return this_.begin(); });
562
+ }
563
+
564
+ iterator end() const {
565
+ TORCH_ILISTREF_UNWRAP(tag_, { return this_.end(); });
566
+ }
567
+
568
+ detail::IListRefConstRef<T> front() const {
569
+ TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::front(this_); });
570
+ }
571
+
572
+ /*
573
+ * Materializes the `IListRef` into a `std::vector`.
574
+ *
575
+ * This should be used when one wishes to either:
576
+ *
577
+ * - iterate over the list more than once: each `IListRefIterator`
578
+ * member function call has to go through a switch, introducing
579
+ * non-negligible overhead
580
+ *
581
+ * - randomly access an arbitrary element using `operator[]`:
582
+ * same reason as above
583
+ */
584
+ detail::MaterializedIListRef<T> materialize() const {
585
+ if (isMaterialized()) {
586
+ return toMaterialized();
587
+ }
588
+
589
+ detail::MaterializedIListRef<T> materialized;
590
+ materialized.reserve(size());
591
+ for (const auto& t : *this) {
592
+ materialized.emplace_back(t);
593
+ }
594
+ return materialized;
595
+ }
596
+
597
+ #define DEFINE_CHECK(TAG, ...) \
598
+ bool is##TAG() const { \
599
+ return tag_ == IListRefTag::TAG; \
600
+ }
601
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_CHECK)
602
+ #undef DEFINE_CHECK
603
+
604
+ bool isNone() const {
605
+ return tag_ == IListRefTag::None;
606
+ }
607
+
608
+ #define DEFINE_CASTING(TAG, ...) \
609
+ const typename detail::IListRefTagImpl<IListRefTag::TAG, T>::list_type& \
610
+ to##TAG() const { \
611
+ TORCH_INTERNAL_ASSERT(is##TAG()); \
612
+ return detail::IListRefTagImpl<IListRefTag::TAG, T>::unwrap(*this); \
613
+ }
614
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_CASTING)
615
+ #undef DEFINE_CASTING
616
+
617
+ private:
618
+ union Payload {
619
+ const boxed_type* boxed;
620
+ unboxed_type unboxed;
621
+ const materialized_type* materialized;
622
+ Payload() : boxed(nullptr) {}
623
+ };
624
+
625
+ Payload payload_;
626
+ IListRefTag tag_;
627
+ };
628
+
629
+ } // namespace c10
630
+
631
+ #include <ATen/core/IListRef_inl.h>
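As the notes in this header explain, `IListRef` is a non-owning view over either a boxed `c10::List` or an unboxed `ArrayRef`, meant to be traversed at most once unless `materialize()` is called first. A minimal sketch of a function taking `at::ITensorListRef` (not part of this commit; the helper name is illustrative):

    #include <ATen/ATen.h>
    #include <ATen/core/IListRef.h>
    #include <ATen/core/List.h>
    #include <vector>

    // Accepts tensors from both the boxed and unboxed "worlds" without converting containers.
    int64_t total_numel(at::ITensorListRef tensors) {
      // Every IListRef/iterator call dispatches on the tag, so for repeated passes
      // the header recommends materializing into a plain std::vector first.
      auto materialized = tensors.materialize();
      int64_t total = 0;
      for (const at::Tensor& t : materialized) {
        total += t.numel();
      }
      return total;
    }

    void caller() {
      std::vector<at::Tensor> unboxed = {at::ones({2}), at::zeros({3})};
      c10::List<at::Tensor> boxed(unboxed);
      total_numel(unboxed);  // unboxed path (c10::ArrayRef)
      total_numel(boxed);    // boxed path (c10::List)
    }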
phivenv/Lib/site-packages/torch/include/ATen/core/IListRef_inl.h ADDED
@@ -0,0 +1,203 @@
1
+ #pragma once
2
+
3
+ #include <ATen/core/List.h>
4
+ #include <ATen/core/Tensor.h>
5
+
6
+ namespace at {
7
+ class Tensor;
8
+ class OptionalTensorRef;
9
+ }
10
+
11
+
12
+ namespace c10::detail {
13
+
14
+ /*
15
+ * Specializations of `IListRefTagImplBase` that implement the default
16
+ * implementation for `IListRefTag::Unboxed`.
17
+ */
18
+ template <typename T, typename ListElemT>
19
+ class IListRefTagImplBase<IListRefTag::Unboxed, T, ListElemT> {
20
+ public:
21
+ using elem_type = ListElemT;
22
+ using list_type = ArrayRef<elem_type>;
23
+
24
+ /*
25
+ * These `unwrap` static methods unwraps the inner containers out
26
+ * of `IListRef<T>` (and `IListRefIterator<T>`). They are required when
27
+ * the macro `TORCH_ILISTREF_UNWRAP` is called.
28
+ */
29
+ static const list_type& unwrap(const IListRef<T>& ilist) {
30
+ return ilist.payload_.unboxed;
31
+ }
32
+
33
+ static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
34
+ return it.payload_.unboxed_iterator;
35
+ }
36
+
37
+ static const typename list_type::const_iterator& unwrap(
38
+ const IListRefIterator<T>& it) {
39
+ return it.payload_.unboxed_iterator;
40
+ }
41
+
42
+ /*
43
+ * We have these function (besides the `unwrap`s above) because the
44
+ * implementation for both `IListRef::operator[]` and `IListRefIterator::operator*`
45
+ * weren't syntatically equal for the existing tags at the time
46
+ * (`Unboxed` and `Boxed`).
47
+ */
48
+ static IListRefConstRef<T> front(const list_type& lst) {
49
+ return lst.front();
50
+ }
51
+
52
+ static IListRefConstRef<T> iterator_get(
53
+ const typename list_type::const_iterator& it) {
54
+ return *it;
55
+ }
56
+ };
57
+
58
+ /*
59
+ * Specializations of `IListRefTagImplBase` that implement the default
60
+ * implementation for `IListRefTag::Boxed`.
61
+ */
62
+ template <typename T, typename ListElemT>
63
+ class IListRefTagImplBase<IListRefTag::Boxed, T, ListElemT> {
64
+ public:
65
+ using elem_type = ListElemT;
66
+ using list_type = List<elem_type>;
67
+
68
+ static const list_type& unwrap(const IListRef<T>& ilist) {
69
+ return *ilist.payload_.boxed;
70
+ }
71
+
72
+ static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
73
+ return it.payload_.boxed_iterator;
74
+ }
75
+
76
+ static const typename list_type::const_iterator& unwrap(
77
+ const IListRefIterator<T>& it) {
78
+ return it.payload_.boxed_iterator;
79
+ }
80
+
81
+ static IListRefConstRef<T> front(const list_type& lst) {
82
+ return lst[0];
83
+ }
84
+
85
+ static IListRefConstRef<T> iterator_get(
86
+ const typename list_type::const_iterator& it) {
87
+ return (*it).get().toTensor();
88
+ }
89
+ };
90
+
91
+ /*
92
+ * Specializations of `IListRefTagImplBase` that implement the default
93
+ * implementation for `IListRefTag::Materialized`.
94
+ */
95
+ template <typename T>
96
+ class IListRefTagImplBase<IListRefTag::Materialized, T, MaterializedIListRefElem<T>> {
97
+ public:
98
+ using elem_type = MaterializedIListRefElem<T>;
99
+ using list_type = MaterializedIListRef<T>;
100
+
101
+ static const list_type& unwrap(const IListRef<T>& ilist) {
102
+ return *ilist.payload_.materialized;
103
+ }
104
+
105
+ static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
106
+ return it.payload_.materialized_iterator;
107
+ }
108
+
109
+ static const typename list_type::const_iterator& unwrap(
110
+ const IListRefIterator<T>& it) {
111
+ return it.payload_.materialized_iterator;
112
+ }
113
+
114
+ static IListRefConstRef<T> front(const list_type& lst) {
115
+ return lst[0];
116
+ }
117
+
118
+ static IListRefConstRef<T> iterator_get(
119
+ const typename list_type::const_iterator& it) {
120
+ return *it;
121
+ }
122
+ };
123
+
124
+ /*
125
+ * [Note: ITensorListRef]
126
+ * Specializations necessary for `IListRef<at::Tensor>` type.
127
+ *
128
+ * Since the default implementations are usually done with supporting
129
+ * `Tensor` in mind, we only have to inherit from the base implementations.
130
+ */
131
+ template <>
132
+ class IListRefTagImpl<IListRefTag::Unboxed, at::Tensor>
133
+ : public IListRefTagImplBase<IListRefTag::Unboxed, at::Tensor> {};
134
+
135
+ template <>
136
+ class IListRefTagImpl<IListRefTag::Boxed, at::Tensor>
137
+ : public IListRefTagImplBase<IListRefTag::Boxed, at::Tensor> {};
138
+
139
+ template <>
140
+ class IListRefTagImpl<IListRefTag::Materialized, at::Tensor>
141
+ : public IListRefTagImplBase<
142
+ IListRefTag::Materialized,
143
+ at::Tensor,
144
+ MaterializedIListRefElem<at::Tensor>> {};
145
+
146
+ /*
147
+ * [Note: IOptTensorListRef]
148
+ * Specializations necessary for `IListRef<at::OptionalTensorRef>` type.
149
+ *
150
+ * We can't get an `at::OptionalTensorRef` directly from an instance of
151
+ * `List<optional<Tensor>>` (the type that corresponds to the boxed world).
152
+ *
153
+ * So, the default implementation won't help us. Thus, we have to implement
154
+ * this method ourselves.
155
+ */
156
+ template <>
157
+ class IListRefTagImpl<IListRefTag::Unboxed, at::OptionalTensorRef>
158
+ : public IListRefTagImplBase<IListRefTag::Unboxed, at::OptionalTensorRef> {};
159
+
160
+ template <>
161
+ class IListRefTagImpl<IListRefTag::Boxed, at::OptionalTensorRef>
162
+ : public IListRefTagImplBase<IListRefTag::Boxed, at::OptionalTensorRef, std::optional<at::Tensor>> {
163
+
164
+ public:
165
+ /*
166
+ * Given an instance of the types corresponding to the `Boxed` tag, we override
167
+ * the default implementation, so that we can return a `at::OptionalTensorRef`.
168
+ */
169
+ static IListRefConstRef<at::OptionalTensorRef> iterator_get(
170
+ const typename list_type::const_iterator& it) {
171
+ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdangling-reference")
172
+ const auto& ivalue = (*it).get();
173
+ C10_DIAGNOSTIC_POP()
174
+ if (!ivalue.isNone()) {
175
+ const auto& tensor = ivalue.toTensor();
176
+ return (tensor.defined()) ? tensor : at::OptionalTensorRef{};
177
+ }
178
+ return {};
179
+ }
180
+ };
181
+
182
+ template <>
183
+ class IListRefTagImpl<IListRefTag::Materialized, at::OptionalTensorRef>
184
+ : public IListRefTagImplBase<
185
+ IListRefTag::Materialized,
186
+ at::OptionalTensorRef,
187
+ MaterializedIListRefElem<at::OptionalTensorRef>> {};
188
+
189
+ } // namespace c10::detail
190
+
191
+
192
+ namespace at {
193
+
194
+ // [Note: ITensorListRef]
195
+ using ITensorListRef = c10::IListRef<at::Tensor>;
196
+ using ITensorListRefIterator = c10::IListRefIterator<at::Tensor>;
197
+ using MaterializedITensorListRef = c10::detail::MaterializedIListRef<at::Tensor>;
198
+ // [Note: IOptTensorListRef]
199
+ using IOptTensorListRef = c10::IListRef<at::OptionalTensorRef>;
200
+ using IOptTensorListRefIterator = c10::IListRefIterator<at::OptionalTensorRef>;
201
+ using MaterializedIOptTensorListRef = c10::detail::MaterializedIListRef<at::OptionalTensorRef>;
202
+
203
+ } // namespace at
phivenv/Lib/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h ADDED
@@ -0,0 +1,111 @@
+ #pragma once
+
+ // The legacy mechanism for dispatching operators in ATen is a Type
+ // object, which is essentially a giant virtual dispatch table
+ // for every operation we support dynamically dispatching over.
+ //
+ // This has been deprecated in favor of ATenDispatch, and in the future,
+ // c10 dispatcher.
+ // TODO: Clean up what remains here
+
+ #include <c10/core/impl/LocalDispatchKeySet.h>
+
+ namespace at {
+
+ // A RAII, thread local (!) guard that will disable dispatch to variable
+ // handler.
+ //
+ // NOTE [ Treating Variables as non-Variables in type dispatch ]
+ //
+ // What exactly does AutoDispatchBelowAutograd do? The short answer is, it causes
+ // dispatches on ATen functions to go to the non-variable implementation,
+ // bypassing autograd handling (and also profiling and tracing).
+ //
+ // To understand why this guard exists, it's helpful to understand the history
+ // behind how Variable was implemented. Previously, Variables were implemented
+ // as a wrapper on Tensors; so the act of processing a Variable involved
+ // unwrapping the underlying Tensor, and then calling the underlying base
+ // operation on /that/ operation
+ //
+ // However, after the Variable/Tensor merge, there is no concept of unwrapping
+ // a tensor anymore. If you just call the operation on the same variable
+ // again inside your VariableType handler, you'll dispatch back to
+ // VariableType, which is not what we want.
+ //
+ // The solution to the above problem is to add `at::AutoDispatchBelowAutograd`, which
+ // when enabled will cause `legacyTensorType()` and `getType()` to always return
+ // non-Variable type, even if the tensor being called on is a variable.
+
+ /* Note [AutoDispatchBelowAutograd]
+ * AutoDispatchBelowAutograd is **INTERNAL ONLY** that it should be used
+ * for kernel implementations and customized C++ kernels.
+ * If you are looking for a guard to run workload in inference mode, please use
+ * c10::InferenceMode RAII which is user facing API.
+ * In the past AutoDispatchBelowAutograd(or its old version AutoNonVariableTypeMode)
+ * was used in the user code for inference-only workload, this was under risk of
+ * producing wrong results silently in some edge cases. For example:
+ * ```
+ * torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
+ * torch::Tensor out = s * s;
+ * {
+ * at::AutoDispatchBelowAutograd guard;
+ * s.add_(1); // Skips version bump on `s`.
+ * }
+ * // WRONG GRADIENT! s.grad() are now computed using `s` value after the
+ * // inplace update.
+ * out.backward(torch::ones_like(out));
+ * ```
+ * Users should use `c10::InferenceMode` here so that it'll properly throw an
+ * error saying "one of the variables needed for gradient computation has be modified."
+ */
+ struct TORCH_API AutoDispatchBelowAutograd {
+ AutoDispatchBelowAutograd() :
+ autograd_guard_(c10::autograd_dispatch_keyset) {
+ }
+
+ // disable all autograd dispatch keys
+ c10::impl::ExcludeDispatchKeyGuard autograd_guard_;
+ };
+
+ // TODO: AutoNonVariableTypeMode should be removed in release 1.10.
+ struct TORCH_API AutoNonVariableTypeMode {
+ AutoNonVariableTypeMode(bool enabled = true) :
+ autograd_guard_(c10::autograd_dispatch_keyset) {
+ TORCH_WARN_ONCE("AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. "
+ "For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, "
+ "If you are looking for a user facing API to enable running your inference-only "
+ "workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code "
+ "is under risk of producing silent wrong result in some edge cases. "
+ "See Note [AutoDispatchBelowAutograd] for more details.");
+ TORCH_INTERNAL_ASSERT(enabled);
+ }
+
+ // disable all autograd dispatch keys
+ c10::impl::ExcludeDispatchKeyGuard autograd_guard_;
+ };
+
+ struct TORCH_API AutoDispatchSkipFunctionalize {
+ AutoDispatchSkipFunctionalize() :
+ dispatch_key_guard_(c10::DispatchKeySet(c10::DispatchKey::Functionalize)) {
+ }
+ c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_;
+ };
+
+ /* Note [AutoDispatchBelowADInplaceOrView]
+ * AutoDispatchBelowADInplaceOrView is equivalent to AutoNonVariableTypeMode
+ * before we split inplace & view ops out of VariableType kernel.
+ * Note this guard is used in VariableType kernels for functional ops
+ * as well as ADInplaceOrView kernels for inplace/view ops to enforce the
+ * Invariant:
+ * Once you are in VariableType/ADInplaceOrView kernel for an op,
+ * you never go back to a kernel on same dispatch key until
+ * you finish the current op.
+ */
+ struct TORCH_API AutoDispatchBelowADInplaceOrView {
+ AutoDispatchBelowADInplaceOrView() :
+ dispatch_key_guard_(c10::autograd_dispatch_keyset_with_ADInplaceOrView) {
+ }
+ // disable Autograd & ADInplaceOrView dispatch keys
+ c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_;
+ };
+ } // namespace at
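The notes above mark these guards as internal-only and point user code at c10::InferenceMode instead. A minimal sketch of the recommended user-facing pattern (not part of this commit; assumes an ordinary libtorch build):

    #include <ATen/ATen.h>
    #include <c10/core/InferenceMode.h>

    at::Tensor inference_only(const at::Tensor& input) {
      // RAII guard: inside this scope autograd handling is skipped safely,
      // unlike the internal-only dispatch-key guards declared above.
      c10::InferenceMode guard;
      return input.mul(2);
    }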
phivenv/Lib/site-packages/torch/include/ATen/core/List.h ADDED
@@ -0,0 +1,491 @@
1
+ #pragma once
2
+
3
+ #include <ATen/core/ivalue_to.h>
4
+ #include <ATen/core/jit_type_base.h>
5
+ #include <c10/macros/Macros.h>
6
+ #include <c10/macros/Export.h>
7
+ #include <c10/util/TypeTraits.h>
8
+ #include <c10/util/TypeList.h>
9
+ #include <c10/util/intrusive_ptr.h>
10
+ #include <c10/util/ArrayRef.h>
11
+ #include <optional>
12
+ #include <vector>
13
+
14
+ namespace at {
15
+ class Tensor;
16
+ }
17
+ namespace c10 {
18
+ struct IValue;
19
+ template<class T> class List;
20
+ struct Type;
21
+
22
+ namespace detail {
23
+
24
+ struct ListImpl final : public c10::intrusive_ptr_target {
25
+ using list_type = std::vector<IValue>;
26
+
27
+ explicit TORCH_API ListImpl(list_type list_, TypePtr elementType_);
28
+
29
+ list_type list;
30
+
31
+ TypePtr elementType;
32
+
33
+ intrusive_ptr<ListImpl> copy() const {
34
+ return make_intrusive<ListImpl>(list, elementType);
35
+ }
36
+ friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs);
37
+ };
38
+ }
39
+
40
+ namespace impl {
41
+
42
+ template<class T, class Iterator> class ListIterator;
43
+
44
+ template<class T, class Iterator> class ListElementReference;
45
+
46
+ template<class T, class Iterator>
47
+ void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept;
48
+
49
+ template<class T, class Iterator>
50
+ bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs);
51
+
52
+ template<class T, class Iterator>
53
+ bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs);
54
+
55
+ template<class T>
56
+ struct ListElementConstReferenceTraits {
57
+ // In the general case, we use IValue::to().
58
+ using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return<T>::type;
59
+ };
60
+
61
+ // There is no to() overload for std::optional<std::string>.
62
+ template<>
63
+ struct ListElementConstReferenceTraits<std::optional<std::string>> {
64
+ using const_reference = std::optional<std::reference_wrapper<const std::string>>;
65
+ };
66
+
67
+ template<class T, class Iterator>
68
+ class ListElementReference final {
69
+ public:
70
+ operator std::conditional_t<
71
+ std::is_reference_v<typename c10::detail::
72
+ ivalue_to_const_ref_overload_return<T>::type>,
73
+ const T&,
74
+ T>() const;
75
+
76
+ ListElementReference& operator=(T&& new_value) &&;
77
+
78
+ ListElementReference& operator=(const T& new_value) &&;
79
+
80
+ // assigning another ref to this assigns the underlying value
81
+ ListElementReference& operator=(ListElementReference&& rhs) && noexcept;
82
+
83
+ const IValue& get() const& {
84
+ return *iterator_;
85
+ }
86
+
87
+ friend void swap<T, Iterator>(ListElementReference&& lhs, ListElementReference&& rhs) noexcept;
88
+
89
+ ListElementReference(const ListElementReference&) = delete;
90
+ ListElementReference& operator=(const ListElementReference&) = delete;
91
+ ~ListElementReference() = default;
92
+
93
+ private:
94
+ ListElementReference(Iterator iter)
95
+ : iterator_(iter) {}
96
+
97
+ // allow moving, but only our friends (i.e. the List class) can move us
98
+ ListElementReference(ListElementReference&&) noexcept = default;
99
+ ListElementReference& operator=(ListElementReference&& rhs) & noexcept {
100
+ iterator_ = std::move(rhs.iterator_);
101
+ return *this;
102
+ }
103
+
104
+ friend class List<T>;
105
+ friend class ListIterator<T, Iterator>;
106
+
107
+ Iterator iterator_;
108
+ };
109
+
110
+ // this wraps vector::iterator to make sure user code can't rely
111
+ // on it being the type of the underlying vector.
112
+ template <class T, class Iterator>
113
+ class ListIterator final {
114
+ public:
115
+ // C++17 friendly std::iterator implementation
116
+ using iterator_category = std::random_access_iterator_tag;
117
+ using value_type = T;
118
+ using difference_type = std::ptrdiff_t;
119
+ using pointer = T*;
120
+ using reference = ListElementReference<T, Iterator>;
121
+
122
+ explicit ListIterator() = default;
123
+ ~ListIterator() = default;
124
+
125
+ ListIterator(const ListIterator&) = default;
126
+ ListIterator(ListIterator&&) noexcept = default;
127
+ ListIterator& operator=(const ListIterator&) = default;
128
+ ListIterator& operator=(ListIterator&&) noexcept = default;
129
+
130
+ ListIterator& operator++() {
131
+ ++iterator_;
132
+ return *this;
133
+ }
134
+
135
+ ListIterator operator++(int) {
136
+ ListIterator copy(*this);
137
+ ++*this;
138
+ return copy;
139
+ }
140
+
141
+ ListIterator& operator--() {
142
+ --iterator_;
143
+ return *this;
144
+ }
145
+
146
+ ListIterator operator--(int) {
147
+ ListIterator copy(*this);
148
+ --*this;
149
+ return copy;
150
+ }
151
+
152
+ ListIterator& operator+=(typename List<T>::size_type offset) {
153
+ iterator_ += offset;
154
+ return *this;
155
+ }
156
+
157
+ ListIterator& operator-=(typename List<T>::size_type offset) {
158
+ iterator_ -= offset;
159
+ return *this;
160
+ }
161
+
162
+ ListIterator operator+(typename List<T>::size_type offset) const {
163
+ return ListIterator{iterator_ + offset};
164
+ }
165
+
166
+ ListIterator operator-(typename List<T>::size_type offset) const {
167
+ return ListIterator{iterator_ - offset};
168
+ }
169
+
170
+ friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) {
171
+ return lhs.iterator_ - rhs.iterator_;
172
+ }
173
+
174
+ ListElementReference<T, Iterator> operator*() const {
175
+ return {iterator_};
176
+ }
177
+
178
+ ListElementReference<T, Iterator> operator[](typename List<T>::size_type offset) const {
179
+ return {iterator_ + offset};
180
+ }
181
+
182
+ private:
183
+ explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {}
184
+
185
+ Iterator iterator_;
186
+
187
+ friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) {
188
+ return lhs.iterator_ == rhs.iterator_;
189
+ }
190
+
191
+ friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) {
192
+ return !(lhs == rhs);
193
+ }
194
+
195
+ friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) {
196
+ return lhs.iterator_ < rhs.iterator_;
197
+ }
198
+
199
+ friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) {
200
+ return lhs.iterator_ <= rhs.iterator_;
201
+ }
202
+
203
+ friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) {
204
+ return lhs.iterator_ > rhs.iterator_;
205
+ }
206
+
207
+ friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) {
208
+ return lhs.iterator_ >= rhs.iterator_;
209
+ }
210
+
211
+ friend class ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
212
+ friend class List<T>;
213
+ };
214
+
215
+ template<class T> List<T> toTypedList(List<IValue> list);
216
+ template<class T> List<IValue> toList(List<T>&& list);
217
+ template<class T> List<IValue> toList(const List<T>& list);
218
+ const IValue* ptr_to_first_element(const List<IValue>& list);
219
+ }
220
+
221
+ /**
222
+ * An object of this class stores a list of values of type T.
223
+ *
224
+ * This is a pointer type. After a copy, both Lists
225
+ * will share the same storage:
226
+ *
227
+ * > List<int> a;
228
+ * > List<int> b = a;
229
+ * > b.push_back(3);
230
+ * > ASSERT(3 == a.get(0));
231
+ *
232
+ * We use this class in the PyTorch kernel API instead of
233
+ * std::vector<T>, because that allows us to do optimizations
234
+ * and switch out the underlying list implementation without
235
+ * breaking backwards compatibility for the kernel API.
236
+ */
237
+ template<class T>
238
+ // NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
239
+ class List final {
240
+ private:
241
+ // This is an intrusive_ptr because List is a pointer type.
242
+ // Invariant: This will never be a nullptr, there will always be a valid
243
+ // ListImpl.
244
+ c10::intrusive_ptr<c10::detail::ListImpl> impl_;
245
+
246
+ using internal_reference_type = impl::ListElementReference<T, typename c10::detail::ListImpl::list_type::iterator>;
247
+ using internal_const_reference_type = typename impl::ListElementConstReferenceTraits<T>::const_reference;
248
+
249
+ public:
250
+ using value_type = T;
251
+ using size_type = typename c10::detail::ListImpl::list_type::size_type;
252
+ using iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
253
+ using const_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
254
+ using reverse_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::reverse_iterator>;
255
+
256
+ /**
257
+ * Constructs an empty list.
258
+ */
259
+ explicit List();
260
+
261
+ /**
262
+ * Constructs a list with some initial values.
263
+ * Example:
264
+ * List<int> a({2, 3, 4});
265
+ */
266
+ List(std::initializer_list<T> initial_values);
267
+ explicit List(ArrayRef<T> initial_values);
268
+
269
+ /**
270
+ * Create a generic list with runtime type information.
271
+ * This only works for c10::impl::GenericList and is not part of the public API
272
+ * but only supposed to be used internally by PyTorch.
273
+ */
274
+ explicit List(TypePtr elementType);
275
+
276
+ List(const List&) = default;
277
+ List& operator=(const List&) = default;
278
+ ~List() = default;
279
+
280
+ /**
281
+ * Create a new List pointing to a deep copy of the same data.
282
+ * The List returned is a new list with separate storage.
283
+ * Changes in it are not reflected in the original list or vice versa.
284
+ */
285
+ List copy() const;
286
+
287
+ /**
288
+ * Returns the element at specified location pos, with bounds checking.
289
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
290
+ */
291
+ internal_const_reference_type get(size_type pos) const;
292
+
293
+ /**
294
+ * Moves out the element at the specified location pos and returns it, with bounds checking.
295
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
296
+ * The list contains an invalid element at position pos afterwards. Any operations
297
+ * on it before re-setting it are invalid.
298
+ */
299
+ value_type extract(size_type pos) const;
300
+
301
+ /**
302
+ * Returns a reference to the element at specified location pos, with bounds checking.
303
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
304
+ *
305
+ * You cannot store the reference, but you can read it and assign new values to it:
306
+ *
307
+ * List<int64_t> list = ...;
308
+ * list[2] = 5;
309
+ * int64_t v = list[1];
310
+ */
311
+ internal_const_reference_type operator[](size_type pos) const;
312
+
313
+ internal_reference_type operator[](size_type pos);
314
+
315
+ /**
316
+ * Assigns a new value to the element at location pos.
317
+ */
318
+ void set(size_type pos, const value_type& value) const;
319
+
320
+ /**
321
+ * Assigns a new value to the element at location pos.
322
+ */
323
+ void set(size_type pos, value_type&& value) const;
324
+
325
+ /**
326
+ * Returns an iterator to the first element of the container.
327
+ * If the container is empty, the returned iterator will be equal to end().
328
+ */
329
+ iterator begin() const;
330
+
331
+ /**
332
+ * Returns an iterator to the element following the last element of the container.
333
+ * This element acts as a placeholder; attempting to access it results in undefined behavior.
334
+ */
335
+ iterator end() const;
336
+
337
+ /**
338
+ * Checks if the container has no elements.
339
+ */
340
+ bool empty() const;
341
+
342
+ /**
343
+ * Returns the number of elements in the container
344
+ */
345
+ size_type size() const;
346
+
347
+ /**
348
+ * Increase the capacity of the vector to a value that's greater or equal to new_cap.
349
+ */
350
+ void reserve(size_type new_cap) const;
351
+
352
+ /**
353
+ * Erases all elements from the container. After this call, size() returns zero.
354
+ * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated.
355
+ */
356
+ void clear() const;
357
+
358
+ /**
359
+ * Inserts value before pos.
360
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
361
+ */
362
+ iterator insert(iterator pos, const T& value) const;
363
+
364
+ /**
365
+ * Inserts value before pos.
366
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
367
+ */
368
+ iterator insert(iterator pos, T&& value) const;
369
+
370
+ /**
371
+ * Inserts a new element into the container directly before pos.
372
+ * The new element is constructed with the given arguments.
373
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
374
+ */
375
+ template<class... Args>
376
+ iterator emplace(iterator pos, Args&&... value) const;
377
+
378
+ /**
379
+ * Appends the given element value to the end of the container.
380
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
381
+ */
382
+ void push_back(const T& value) const;
383
+
384
+ /**
385
+ * Appends the given element value to the end of the container.
386
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
387
+ */
388
+ void push_back(T&& value) const;
389
+
390
+ /**
391
+ * Appends the given list to the end of the container. Uses at most one memory allocation.
392
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
393
+ */
394
+ void append(List<T> lst) const;
395
+
396
+ /**
397
+ * Appends the given element value to the end of the container.
398
+ * The new element is constructed with the given arguments.
399
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
400
+ */
401
+ template<class... Args>
402
+ void emplace_back(Args&&... args) const;
403
+
404
+ /**
405
+ * Removes the element at pos.
406
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
407
+ */
408
+ iterator erase(iterator pos) const;
409
+
410
+ /**
411
+ * Removes the elements in the range [first, last).
412
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
413
+ */
414
+ iterator erase(iterator first, iterator last) const;
415
+
416
+ /**
417
+ * Removes the last element of the container.
418
+ * Calling pop_back on an empty container is undefined.
419
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
420
+ */
421
+ void pop_back() const;
422
+
423
+ /**
424
+ * Resizes the container to contain count elements.
425
+ * If the current size is less than count, additional default-inserted elements are appended.
426
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
427
+ */
428
+ void resize(size_type count) const;
429
+
430
+ /**
431
+ * Resizes the container to contain count elements.
432
+ * If the current size is less than count, additional copies of value are appended.
433
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
434
+ */
435
+ void resize(size_type count, const T& value) const;
436
+
437
+ /**
438
+ * Value equality comparison. This function implements Python-like semantics for
439
+ * equality: two lists with the same identity (e.g. same pointer) trivially
440
+ * compare equal, otherwise each element is compared for equality.
441
+ */
442
+ template <class T_>
443
+ friend bool operator==(const List<T_>& lhs, const List<T_>& rhs);
444
+
445
+ template <class T_>
446
+ friend bool operator!=(const List<T_>& lhs, const List<T_>& rhs);
447
+
448
+ /**
449
+ * Identity comparison. Returns true if and only if `rhs` represents the same
450
+ * List object as `this`.
451
+ */
452
+ bool is(const List<T>& rhs) const;
453
+
454
+ std::vector<T> vec() const;
455
+
456
+ /**
457
+ * Returns the number of Lists currently pointing to this same list.
458
+ * If this is the only instance pointing to this list, returns 1.
459
+ */
460
+ // TODO Test use_count
461
+ size_t use_count() const;
462
+
463
+ TypePtr elementType() const;
464
+
465
+ // See [unsafe set type] for why this exists.
466
+ void unsafeSetElementType(TypePtr t);
467
+
468
+ private:
469
+ explicit List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements);
470
+ explicit List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements);
471
+ friend struct IValue;
472
+ template<class T_> friend List<T_> impl::toTypedList(List<IValue>);
473
+ template<class T_> friend List<IValue> impl::toList(List<T_>&&);
474
+ template<class T_> friend List<IValue> impl::toList(const List<T_>&);
475
+ friend const IValue* impl::ptr_to_first_element(const List<IValue>& list);
476
+ };
477
+
478
+ namespace impl {
479
+ // GenericList is how IValue stores lists. It is, however, not part of the
480
+ // public API. Kernels should use Lists with concrete types instead
481
+ // (maybe except for some internal prim ops).
482
+ using GenericList = List<IValue>;
483
+
484
+ }
485
+ }
486
+
487
+ namespace torch {
488
+ template<class T> using List = c10::List<T>;
489
+ }
490
+
491
+ #include <ATen/core/List_inl.h> // IWYU pragma: keep
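A short sketch of the pointer semantics documented above (illustrative, not part of the header):

    #include <ATen/core/List.h>

    void list_semantics_demo() {
      c10::List<int64_t> a({2, 3, 4});
      c10::List<int64_t> b = a;         // shallow copy: a and b share storage
      b.push_back(5);                   // now a.get(3) == 5 and a.is(b) is true
      c10::List<int64_t> c = a.copy();  // deep copy: separate storage
      c.set(0, 42);                     // does not affect a or b
    }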
phivenv/Lib/site-packages/torch/include/ATen/core/List_inl.h ADDED
@@ -0,0 +1,353 @@
1
+ #pragma once
2
+
3
+ #include <ATen/core/jit_type_base.h>
4
+ #include <ATen/core/ivalue.h>
5
+
6
+ namespace c10 {
7
+
8
+ template<class T> decltype(auto) getTypePtr();
9
+ std::string toString(const Type& type);
10
+
11
+ template<class T>
12
+ List<T>::List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements)
13
+ : impl_(std::move(elements)) {}
14
+
15
+ template<class T>
16
+ List<T>::List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements)
17
+ : impl_(elements) {}
18
+
19
+ template<class T>
20
+ List<T>::List()
21
+ : List(make_intrusive<c10::detail::ListImpl>(
22
+ typename c10::detail::ListImpl::list_type(),
23
+ getTypePtr<T>())) {
24
+ static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType) instead.");
25
+ }
26
+
27
+ template<class T>
28
+ List<T>::List(ArrayRef<T> values)
29
+ : List(make_intrusive<c10::detail::ListImpl>(
30
+ typename c10::detail::ListImpl::list_type(),
31
+ getTypePtr<T>())) {
32
+ static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
33
+ impl_->list.reserve(values.size());
34
+ for (const T& element : values) {
35
+ impl_->list.push_back(element);
36
+ }
37
+ }
38
+
39
+ template<class T>
40
+ List<T>::List(std::initializer_list<T> initial_values)
41
+ : List(ArrayRef<T>(initial_values)) {
42
+ static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
43
+ }
44
+
45
+ template<class T>
46
+ List<T>::List(TypePtr elementType)
47
+ : List(make_intrusive<c10::detail::ListImpl>(
48
+ typename c10::detail::ListImpl::list_type(),
49
+ std::move(elementType))) {
50
+ static_assert(std::is_same_v<T, IValue> || std::is_same_v<T, c10::intrusive_ptr<ivalue::Future>>,
51
+ "This constructor is only valid for c10::impl::GenericList or List<Future>.");
52
+ }
53
+
54
+ namespace impl {
55
+ template<class T>
56
+ List<T> toTypedList(impl::GenericList list) {
57
+ // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
58
+ // because upcasting would allow people to add types into the new list that would break the old list.
59
+ // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
60
+ // allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
61
+ // without having to copy it. This is also used to provide backwards compatibility with some old models
62
+ // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
63
+ // as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
64
+ // have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
65
+ TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
66
+ || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr<T>()))
67
+ , "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(*getTypePtr<T>()), ">. Types mismatch.");
68
+ return List<T>(std::move(list.impl_));
69
+ }
70
+
71
+ template<class T>
72
+ impl::GenericList toList(List<T>&& list) {
73
+ return GenericList(std::move(list.impl_));
74
+ }
75
+ template<class T>
76
+ impl::GenericList toList(const List<T>& list) {
77
+ return GenericList(list.impl_);
78
+ }
79
+ }
80
+
81
+ template<class T>
82
+ List<T> List<T>::copy() const {
83
+ return List<T>(impl_->copy());
84
+ }
85
+
86
+ namespace detail {
87
+ template<class T>
88
+ T list_element_to(T element) {
89
+ return element;
90
+ }
91
+ template<class T>
92
+ T list_element_to(const IValue& element) {
93
+ return element.template to<T>();
94
+ }
95
+ template<class T>
96
+ T list_element_to(IValue&& element) {
97
+ return std::move(element).template to<T>();
98
+ }
99
+ template<class T>
100
+ struct ListElementFrom {
101
+ static IValue from(const T& element) {
102
+ return element;
103
+ }
104
+ static IValue from(T&& element) {
105
+ return std::move(element);
106
+ }
107
+ };
108
+ template<>
109
+ struct ListElementFrom<IValue> {
110
+ static const IValue& from(const IValue& element) {
111
+ return element;
112
+ }
113
+ static IValue&& from(IValue&& element) {
114
+ return std::move(element);
115
+ }
116
+ };
117
+ }
118
+
119
+ namespace impl {
120
+
121
+ template <class T, class Iterator>
122
+ ListElementReference<T, Iterator>::operator std::conditional_t<
123
+ std::is_reference_v<typename c10::detail::ivalue_to_const_ref_overload_return<
124
+ T>::type>,
125
+ const T&,
126
+ T>() const {
127
+ return iterator_->template to<T>();
128
+ }
129
+
130
+ template<class T, class Iterator>
131
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(T&& new_value) && {
132
+ *iterator_ = c10::detail::ListElementFrom<T>::from(std::move(new_value));
133
+ return *this;
134
+ }
135
+
136
+ template<class T, class Iterator>
137
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(const T& new_value) && {
138
+ *iterator_ = c10::detail::ListElementFrom<T>::from(new_value);
139
+ return *this;
140
+ }
141
+
142
+ template<class T, class Iterator>
143
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(ListElementReference<T, Iterator>&& rhs) && noexcept {
144
+ *iterator_ = *rhs.iterator_;
145
+ return *this;
146
+ }
147
+
148
+ template<class T, class Iterator>
149
+ void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept {
150
+ std::swap(*lhs.iterator_, *rhs.iterator_);
151
+ }
152
+
153
+ template<class T, class Iterator>
154
+ bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs) {
155
+ const T& lhs_tmp = lhs;
156
+ return lhs_tmp == rhs;
157
+ }
158
+
159
+ template<class T, class Iterator>
160
+ inline bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs) {
161
+ return rhs == lhs;
162
+ }
163
+
164
+ template<class T>
165
+ inline typename ListElementConstReferenceTraits<T>::const_reference
166
+ list_element_to_const_ref(const IValue& element) {
167
+ return element.template to<T>();
168
+ }
169
+
170
+ template<>
171
+ inline typename ListElementConstReferenceTraits<std::optional<std::string>>::const_reference
172
+ list_element_to_const_ref<std::optional<std::string>>(const IValue& element) {
173
+ return element.toOptionalStringRef();
174
+ }
175
+
176
+ } // namespace impl
177
+
178
+ template<class T>
179
+ void List<T>::set(size_type pos, const value_type& value) const {
180
+ impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(value);
181
+ }
182
+
183
+ template<class T>
184
+ void List<T>::set(size_type pos, value_type&& value) const {
185
+ impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(std::move(value));
186
+ }
187
+
188
+ template<class T>
189
+ typename List<T>::internal_const_reference_type List<T>::get(size_type pos) const {
190
+ return operator[](pos);
191
+ }
192
+
193
+ template<class T>
194
+ typename List<T>::internal_const_reference_type List<T>::operator[](size_type pos) const {
195
+ return c10::impl::list_element_to_const_ref<T>(impl_->list.at(pos));
196
+ }
197
+
198
+ template<class T>
199
+ typename List<T>::internal_reference_type List<T>::operator[](size_type pos) {
200
+ static_cast<void>(impl_->list.at(pos)); // Throw the exception if it is out of range.
201
+ return {impl_->list.begin() + static_cast<typename decltype(impl_->list)::difference_type>(pos)};
202
+ }
203
+
204
+ template<class T>
205
+ typename List<T>::value_type List<T>::extract(size_type pos) const {
206
+ auto& elem = impl_->list.at(pos);
207
+ auto result = c10::detail::list_element_to<T>(std::move(elem));
208
+ // Reset the list element to a T() instead of None to keep it correctly typed
209
+ elem = c10::detail::ListElementFrom<T>::from(T{});
210
+ return result;
211
+ }
212
+
213
+ template<class T>
214
+ typename List<T>::iterator List<T>::begin() const {
215
+ return iterator(impl_->list.begin());
216
+ }
217
+
218
+ template<class T>
219
+ typename List<T>::iterator List<T>::end() const {
220
+ return iterator(impl_->list.end());
221
+ }
222
+
223
+ template<class T>
224
+ bool List<T>::empty() const {
225
+ return impl_->list.empty();
226
+ }
227
+
228
+ template<class T>
229
+ typename List<T>::size_type List<T>::size() const {
230
+ return impl_->list.size();
231
+ }
232
+
233
+ template<class T>
234
+ void List<T>::reserve(size_type new_cap) const {
235
+ impl_->list.reserve(new_cap);
236
+ }
237
+
238
+ template<class T>
239
+ void List<T>::clear() const {
240
+ impl_->list.clear();
241
+ }
242
+
243
+ template<class T>
244
+ typename List<T>::iterator List<T>::insert(iterator pos, const T& value) const {
245
+ return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(value)) };
246
+ }
247
+
248
+ template<class T>
249
+ typename List<T>::iterator List<T>::insert(iterator pos, T&& value) const {
250
+ return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(std::move(value))) };
251
+ }
252
+
253
+ template<class T>
254
+ template<class... Args>
255
+ typename List<T>::iterator List<T>::emplace(iterator pos, Args&&... value) const {
256
+ // TODO Use list_element_from?
257
+ return iterator { impl_->list.emplace(pos.iterator_, std::forward<Args>(value)...) };
258
+ }
259
+
260
+ template<class T>
261
+ void List<T>::push_back(const T& value) const {
262
+ impl_->list.push_back(c10::detail::ListElementFrom<T>::from(value));
263
+ }
264
+
265
+ template<class T>
266
+ void List<T>::push_back(T&& value) const {
267
+ impl_->list.push_back(c10::detail::ListElementFrom<T>::from(std::move(value)));
268
+ }
269
+
270
+ template<class T>
271
+ void List<T>::append(List<T> b) const {
272
+ if (b.use_count() == 1) {
273
+ impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end()));
274
+ } else {
275
+ impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end());
276
+ }
277
+ }
278
+
279
+ template<class T>
280
+ template<class... Args>
281
+ void List<T>::emplace_back(Args&&... args) const {
282
+ // TODO Use list_element_from?
283
+ impl_->list.push_back(T(std::forward<Args>(args)...));
284
+ }
285
+
286
+ template<class T>
287
+ typename List<T>::iterator List<T>::erase(iterator pos) const {
288
+ return iterator { impl_->list.erase(pos.iterator_) };
289
+ }
290
+
291
+ template<class T>
292
+ typename List<T>::iterator List<T>::erase(iterator first, iterator last) const {
293
+ return iterator { impl_->list.erase(first.iterator_, last.iterator_) };
294
+ }
295
+
296
+ template<class T>
297
+ void List<T>::pop_back() const {
298
+ impl_->list.pop_back();
299
+ }
300
+
301
+ template<class T>
302
+ void List<T>::resize(size_type count) const {
303
+ impl_->list.resize(count, T{});
304
+ }
305
+
306
+ template<class T>
307
+ void List<T>::resize(size_type count, const T& value) const {
308
+ impl_->list.resize(count, value);
309
+ }
310
+
311
+ template<class T>
312
+ bool operator==(const List<T>& lhs, const List<T>& rhs) {
313
+ // Lists with the same identity trivially compare equal.
314
+ if (lhs.impl_ == rhs.impl_) {
315
+ return true;
316
+ }
317
+
318
+ // Otherwise, just compare values directly.
319
+ return *lhs.impl_ == *rhs.impl_;
320
+ }
321
+
322
+ template<class T>
323
+ bool operator!=(const List<T>& lhs, const List<T>& rhs) {
324
+ return !(lhs == rhs);
325
+ }
326
+
327
+ template<class T>
328
+ bool List<T>::is(const List<T>& rhs) const {
329
+ return this->impl_ == rhs.impl_;
330
+ }
331
+
332
+ template<class T>
333
+ std::vector<T> List<T>::vec() const {
334
+ std::vector<T> result(begin(), end());
335
+ return result;
336
+ }
337
+
338
+ template<class T>
339
+ size_t List<T>::use_count() const {
340
+ return impl_.use_count();
341
+ }
342
+
343
+ template <class T>
344
+ TypePtr List<T>::elementType() const {
345
+ return impl_->elementType;
346
+ }
347
+
348
+ template <class T>
349
+ void List<T>::unsafeSetElementType(TypePtr t) {
350
+ impl_->elementType = std::move(t);
351
+ }
352
+
353
+ }
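The toTypedList/toList helpers above convert between a typed List<T> and the type-erased GenericList used by IValue; a rough round-trip sketch (illustrative, internal API):

    c10::List<int64_t> ints({1, 2, 3});
    // Erase the static element type; the runtime element type is kept in the ListImpl.
    c10::impl::GenericList generic = c10::impl::toList(std::move(ints));
    // Checked cast back; toTypedList TORCH_CHECKs that the element types match.
    c10::List<int64_t> typed = c10::impl::toTypedList<int64_t>(std::move(generic));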
phivenv/Lib/site-packages/torch/include/ATen/core/MT19937RNGEngine.h ADDED
@@ -0,0 +1,194 @@
1
+ #pragma once
2
+
3
+ #include <c10/util/irange.h>
4
+
5
+ // define constants like M_PI and C keywords for MSVC
6
+ #ifdef _MSC_VER
7
+ #ifndef _USE_MATH_DEFINES
8
+ #define _USE_MATH_DEFINES
9
+ #endif
10
+ #include <math.h>
11
+ #endif
12
+
13
+ #include <array>
14
+ #include <cmath>
15
+ #include <cstdint>
16
+
17
+ namespace at {
18
+
19
+ constexpr int MERSENNE_STATE_N = 624;
20
+ constexpr int MERSENNE_STATE_M = 397;
21
+ constexpr uint32_t MATRIX_A = 0x9908b0df;
22
+ constexpr uint32_t UMASK = 0x80000000;
23
+ constexpr uint32_t LMASK = 0x7fffffff;
24
+
25
+ /**
26
+ * Note [Mt19937 Engine implementation]
27
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
28
+ * Originally implemented in:
29
+ * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/CODES/MTARCOK/mt19937ar-cok.c
30
+ * and modified with C++ constructs. Moreover the state array of the engine
31
+ * has been modified to hold 32 bit uints instead of 64 bits.
32
+ *
33
+ * Note that we reimplemented mt19937 instead of using std::mt19937 because,
34
+ * at::mt19937 turns out to be faster in the pytorch codebase. PyTorch builds with -O2
35
+ * by default and following are the benchmark numbers (benchmark code can be found at
36
+ * https://github.com/syed-ahmed/benchmark-rngs):
37
+ *
38
+ * with -O2
39
+ * Time to get 100000000 philox randoms with at::uniform_real_distribution = 0.462759s
40
+ * Time to get 100000000 at::mt19937 randoms with at::uniform_real_distribution = 0.39628s
41
+ * Time to get 100000000 std::mt19937 randoms with std::uniform_real_distribution = 0.352087s
42
+ * Time to get 100000000 std::mt19937 randoms with at::uniform_real_distribution = 0.419454s
43
+ *
44
+ * std::mt19937 is faster when used in conjunction with std::uniform_real_distribution,
45
+ * however we can't use std::uniform_real_distribution because of this bug:
46
+ * http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524. Plus, even if we used
47
+ * std::uniform_real_distribution and filtered out the 1's, it is a different algorithm
48
+ * than what's in pytorch currently and that messes up the tests in tests_distributions.py.
49
+ * The other option, using std::mt19937 with at::uniform_real_distribution is a tad bit slower
50
+ * than at::mt19937 with at::uniform_real_distribution and hence, we went with the latter.
51
+ *
52
+ * Copyright notice:
53
+ * A C-program for MT19937, with initialization improved 2002/2/10.
54
+ * Coded by Takuji Nishimura and Makoto Matsumoto.
55
+ * This is a faster version by taking Shawn Cokus's optimization,
56
+ * Matthe Bellew's simplification, Isaku Wada's real version.
57
+ *
58
+ * Before using, initialize the state by using init_genrand(seed)
59
+ * or init_by_array(init_key, key_length).
60
+ *
61
+ * Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura,
62
+ * All rights reserved.
63
+ *
64
+ * Redistribution and use in source and binary forms, with or without
65
+ * modification, are permitted provided that the following conditions
66
+ * are met:
67
+ *
68
+ * 1. Redistributions of source code must retain the above copyright
69
+ * notice, this list of conditions and the following disclaimer.
70
+ *
71
+ * 2. Redistributions in binary form must reproduce the above copyright
72
+ * notice, this list of conditions and the following disclaimer in the
73
+ * documentation and/or other materials provided with the distribution.
74
+ *
75
+ * 3. The names of its contributors may not be used to endorse or promote
76
+ * products derived from this software without specific prior written
77
+ * permission.
78
+ *
79
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
80
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
81
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
82
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
83
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
84
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
85
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
86
+ * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
87
+ * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
88
+ * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
89
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
90
+ *
91
+ *
92
+ * Any feedback is very welcome.
93
+ * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html
94
+ * email: m-mat @ math.sci.hiroshima-u.ac.jp (remove space)
95
+ */
96
+
97
+ /**
98
+ * mt19937_data_pod is used to get POD data in and out
99
+ * of mt19937_engine. Used in torch.get_rng_state and
100
+ * torch.set_rng_state functions.
101
+ */
102
+ struct mt19937_data_pod {
103
+ uint64_t seed_;
104
+ int left_;
105
+ bool seeded_;
106
+ uint32_t next_;
107
+ std::array<uint32_t, MERSENNE_STATE_N> state_;
108
+ };
109
+
110
+ class mt19937_engine {
111
+ public:
112
+
113
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
114
+ inline explicit mt19937_engine(uint64_t seed = 5489) {
115
+ init_with_uint32(seed);
116
+ }
117
+
118
+ inline mt19937_data_pod data() const {
119
+ return data_;
120
+ }
121
+
122
+ inline void set_data(const mt19937_data_pod& data) {
123
+ data_ = data;
124
+ }
125
+
126
+ inline uint64_t seed() const {
127
+ return data_.seed_;
128
+ }
129
+
130
+ inline bool is_valid() {
131
+ if ((data_.seeded_ == true)
132
+ && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N)
133
+ && (data_.next_ <= MERSENNE_STATE_N)) {
134
+ return true;
135
+ }
136
+ return false;
137
+ }
138
+
139
+ inline uint32_t operator()() {
140
+ if (--(data_.left_) == 0) {
141
+ next_state();
142
+ }
143
+ uint32_t y = *(data_.state_.data() + data_.next_++);
144
+ y ^= (y >> 11);
145
+ y ^= (y << 7) & 0x9d2c5680;
146
+ y ^= (y << 15) & 0xefc60000;
147
+ y ^= (y >> 18);
148
+
149
+ return y;
150
+ }
151
+
152
+ private:
153
+ mt19937_data_pod data_;
154
+
155
+ inline void init_with_uint32(uint64_t seed) {
156
+ data_.seed_ = seed;
157
+ data_.seeded_ = true;
158
+ data_.state_[0] = seed & 0xffffffff;
159
+ for (const auto j : c10::irange(1, MERSENNE_STATE_N)) {
160
+ data_.state_[j] = (1812433253 * (data_.state_[j-1] ^ (data_.state_[j-1] >> 30)) + j);
161
+ }
162
+ data_.left_ = 1;
163
+ data_.next_ = 0;
164
+ }
165
+
166
+ inline uint32_t mix_bits(uint32_t u, uint32_t v) {
167
+ return (u & UMASK) | (v & LMASK);
168
+ }
169
+
170
+ inline uint32_t twist(uint32_t u, uint32_t v) {
171
+ return (mix_bits(u,v) >> 1) ^ (v & 1 ? MATRIX_A : 0);
172
+ }
173
+
174
+ inline void next_state() {
175
+ uint32_t* p = data_.state_.data();
176
+ data_.left_ = MERSENNE_STATE_N;
177
+ data_.next_ = 0;
178
+
179
+ for(int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) {
180
+ *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]);
181
+ }
182
+
183
+ for(int j = MERSENNE_STATE_M; --j; p++) {
184
+ *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]);
185
+ }
186
+
187
+ *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]);
188
+ }
189
+
190
+ };
191
+
192
+ typedef mt19937_engine mt19937;
193
+
194
+ } // namespace at
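A small usage sketch (illustrative): construct the engine with a seed, call operator() for successive 32-bit draws, and use data()/set_data() to snapshot and restore the POD state, as torch.get_rng_state/torch.set_rng_state do:

    at::mt19937 gen(5489);                        // explicit default seed
    uint32_t r0 = gen();                          // first 32-bit draw
    at::mt19937_data_pod snapshot = gen.data();   // capture the engine state
    uint32_t r1 = gen();
    gen.set_data(snapshot);                       // restore; the next draw reproduces r1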
phivenv/Lib/site-packages/torch/include/ATen/core/NamedTensor.h ADDED
@@ -0,0 +1,143 @@
1
+ #pragma once
2
+
3
+ #include <ATen/core/Dimname.h>
4
+ #include <c10/core/TensorImpl.h>
5
+
6
+ namespace at {
7
+
8
+ class TensorBase;
9
+
10
+ // XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen.
11
+ // Due to the c10/ATen library split, TensorImpl cannot depend on Dimname,
12
+ // so we have a couple of workarounds.
13
+ //
14
+ // In the long term, we'll move Dimname to c10 and everything in this file
15
+ // can be refactored out. The main blocker for that is that "c10::Symbol"
16
+ // actually exists outside of c10 and needs to be moved in.
17
+
18
+ // TensorImpl has a unique_ptr<NamedTensorMetaInterface> field.
19
+ // XXX: Ideally we would just put std::optional<vector<Dimname>> into TensorImpl.
20
+ //
21
+ // This class has an important invariant: there must be at least ONE
22
+ // non-wildcard
23
+ struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface {
24
+ // This enum is to remind people that the invariant on constructors is that
25
+ // the list of dimnames must have at least one non-wildcard
26
+ enum HAS_NON_WILDCARD {
27
+ HasNonWildcard
28
+ };
29
+
30
+ explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names)
31
+ : names_(names.vec()) {
32
+ check_invariants();
33
+ }
34
+ explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector<Dimname>&& names)
35
+ : names_(std::move(names)) {
36
+ check_invariants();
37
+ }
38
+
39
+ std::unique_ptr<c10::NamedTensorMetaInterface> clone() const override {
40
+ return std::make_unique<NamedTensorMeta>(HasNonWildcard, names_);
41
+ }
42
+
43
+ DimnameList names() const { return names_; }
44
+
45
+ // Used for an assertion in TensorImpl.h
46
+ int64_t slow_dim() const override {
47
+ return static_cast<int64_t>(names_.size());
48
+ }
49
+
50
+ void check_invariants() const {
51
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
52
+ std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); }));
53
+ }
54
+
55
+ void set_names(HAS_NON_WILDCARD, DimnameList new_names) {
56
+ TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
57
+ std::copy(new_names.begin(), new_names.end(), names_.begin());
58
+ check_invariants();
59
+ }
60
+
61
+ void set_names(HAS_NON_WILDCARD, std::vector<Dimname>&& new_names) {
62
+ TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
63
+ names_ = std::move(new_names);
64
+ check_invariants();
65
+ }
66
+
67
+ // INVARIANT: at least one Dimname is non-WILDCARD
68
+ std::vector<Dimname> names_;
69
+ };
70
+
71
+ // When NamesMode is disabled, then all operations ignore tensors' names fields.
72
+ // Concretely speaking, all tensors are treated as having nullopt names.
73
+ struct TORCH_API NamesMode {
74
+ static bool is_enabled();
75
+ static void set_enabled(bool enabled);
76
+ };
77
+
78
+
79
+ // A RAII, thread local (!) guard that enables or disables names upon
80
+ // construction, and sets it back to the original value upon destruction.
81
+ struct TORCH_API NoNamesGuard {
82
+ NoNamesGuard() : prev_mode(NamesMode::is_enabled()) {
83
+ NamesMode::set_enabled(false);
84
+ }
85
+ NoNamesGuard(const NoNamesGuard&) = delete;
86
+ NoNamesGuard(NoNamesGuard&&) = delete;
87
+ NoNamesGuard& operator=(const NoNamesGuard&) = delete;
88
+ NoNamesGuard& operator=(NoNamesGuard&&) = delete;
89
+ ~NoNamesGuard() {
90
+ if (initialized) {
91
+ reset();
92
+ }
93
+ }
94
+ void reset() {
95
+ TORCH_INTERNAL_ASSERT(initialized);
96
+ NamesMode::set_enabled(prev_mode);
97
+ }
98
+ private:
99
+ bool prev_mode;
100
+ bool initialized{true};
101
+ };
102
+
103
+ void check_names_valid_for(const TensorBase& tensor, DimnameList names);
104
+ void check_names_valid_for(size_t tensor_dim, DimnameList names);
105
+
106
+ // Sets the names of `tensor` to be `names`.
107
+ TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::optional<DimnameList> names);
108
+ TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector<Dimname>&& names, bool validate_names);
109
+
110
+ constexpr size_t kMaxNamedTensorDim = 64;
111
+
112
+ DimnameList default_names(size_t len);
113
+
114
+ namespace impl {
115
+
116
+ // Some helper functions on TensorImpl. Useful for working with names in TH.
117
+ // XXX: Ideally these would exist as methods on TensorImpl
118
+ TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::optional<DimnameList> names, bool validate_names);
119
+ TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names);
120
+
121
+ void check_names_valid_for(TensorImpl* impl, DimnameList names);
122
+
123
+ // Returns true if the tensor's names exist and are not all 'None'.
124
+ // Returns false if the tensor's names don't exist (were not allocated),
125
+ // or if all names are 'None'.
126
+ // We treat not-allocated-names the same as allocated names that are all 'None'.
127
+ TORCH_API bool has_names(const TensorImpl* impl);
128
+
129
+ // Returns the names of the tensor's dimensions.
130
+ // Unnamed tensors are treated as having 'None' in all dimension; this method
131
+ // would return a DimnameList of all 'None's for an unnamed tensor.
132
+ TORCH_API DimnameList get_names(const TensorImpl* impl);
133
+
134
+ // This is more of an implementation detail; one should use impl::get_names /
135
+ // Tensor::names() whenever possible because it provides a cleaner API.
136
+ // Returns the names of the tensor if they have been allocated; returns nullopt
137
+ // instead if the haven't been. The names of a tensor are not allocated if a
138
+ // tensor is constructed with names=None.
139
+ TORCH_API std::optional<DimnameList> get_opt_names(const TensorImpl* impl);
140
+
141
+ } // namespace impl
142
+
143
+ } // namespace at
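A sketch of the NoNamesGuard RAII pattern described above (illustrative):

    bool was_enabled = at::NamesMode::is_enabled();
    {
      at::NoNamesGuard guard;                     // disables names for this scope
      // ... ops here treat every tensor as unnamed ...
    }                                             // destructor restores the previous mode
    TORCH_INTERNAL_ASSERT(at::NamesMode::is_enabled() == was_enabled);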
phivenv/Lib/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h ADDED
@@ -0,0 +1,187 @@
1
+ #pragma once
2
+
3
+ #include <c10/core/ConstantSymNodeImpl.h>
4
+ #include <c10/core/SymNodeImpl.h>
5
+ #include <c10/macros/Export.h>
6
+ #include <c10/util/Exception.h>
7
+ #include <c10/util/intrusive_ptr.h>
8
+ #include <cstdint>
9
+ #include <optional>
10
+ #include <string>
11
+
12
+ namespace c10 {
13
+
14
+ // The motivating usecase for this is to represent the ragged size structure
15
+ // of a jagged tensor [B, [s_0, s_1, s_2], D] as a single integer j0. This
16
+ // allows us to simply return [B, j0, D] if someone queries for the size of our
17
+ // tensor.
18
+ //
19
+ // Morally we define comparison between two nested ints to return true if
20
+ // that comparison holds for all corresponding elements of the arrays they
21
+ // represent. Comparison between a nested int and a plain int is defined
22
+ // similarly.
23
+ //
24
+ // To simulate this desired behavior but also avoid the O(N) cost of checking,
25
+ // we associate each raggedness pattern with an integer "id" that can be used as
26
+ // a proxy to evaluate equality. We also constrain the range of values for this
27
+ // as to enable inequality checks.
28
+ //
29
+ // We also support a positive integer scalar "coeff" that is used for computing
30
+ // strides. For example given, a [B, j0, D] tensor, it can be strided in two
31
+ // different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to
32
+ // differentiate the two cases.
33
+ //
34
+ // During tracing the strides of the outputs need to be a function of the size
35
+ // and strides of the inputs so it is important that NestedIntSymNode itself is
36
+ // able to express this.
37
+ class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
38
+ public:
39
+ // CAUTION: you should probably not be constructing these directly; please
40
+ // the higher-level API in python instead (TODO: actually introduce that).
41
+ explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff)
42
+ : val_(val), coeff_(coeff) {}
43
+
44
+ bool bool_() override {
45
+ return false;
46
+ }
47
+
48
+ bool is_int() override {
49
+ return true;
50
+ }
51
+
52
+ bool is_float() override {
53
+ return false;
54
+ }
55
+
56
+ bool is_bool() override {
57
+ return false;
58
+ }
59
+
60
+ bool is_nested_int() const override {
61
+ return true;
62
+ }
63
+
64
+ bool has_hint() override {
65
+ return true;
66
+ }
67
+
68
+ c10::SymNode wrap_int(int64_t num) override {
69
+ return SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(num));
70
+ }
71
+
72
+ int64_t guard_int(const char* file, int64_t line) override {
73
+ TORCH_CHECK(false);
74
+ }
75
+
76
+ double guard_float(const char* file, int64_t line) override {
77
+ TORCH_CHECK(false, "not a float");
78
+ }
79
+
80
+ bool guard_bool(const char* file, int64_t line) override {
81
+ TORCH_CHECK(false, "not a bool");
82
+ }
83
+
84
+ int64_t int_() override {
85
+ TORCH_CHECK(false);
86
+ }
87
+
88
+ std::string str() override {
89
+ if (coeff_ == 1) {
90
+ return "j" + std::to_string(val_);
91
+ }
92
+ return std::to_string(coeff_) + "*j" + std::to_string(val_);
93
+ }
94
+
95
+ // NOTE [ Inequalities with nested int ]
96
+ //
97
+ // The semantics of nested int when it comes to relations is that it is
98
+ // treated as integer known to be within a certain range,
99
+ //
100
+ // j0 \in [2, int64_t::max]
101
+ //
102
+ // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False).
103
+ // This is a useful default range for the raggedness pattern of a jagged
104
+ // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1
105
+ // specialization checks.
106
+ //
107
+ // [ Indeterminate inequalities error out ]
108
+ //
109
+ // Given the semantic defined above, certain relations like j0 < 3 are thus
110
+ // indeterminable. In our impl today, evaluating such relations error
111
+ //
112
+ // It may seem convenient to just define indeterminate relations to return
113
+ // False, but the implementation we maintain in parallel using sympy does not
114
+ // allow this.
115
+ //
116
+ // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are,
117
+ // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This
118
+ // would mean that means that if we define the indeterminate j0 >= 3 to be
119
+ // False, the also indeterminate j0 < 3 will be evaluated to be True!
120
+ //
121
+ // [ Coefficient are assumed positive ]
122
+ //
123
+ // For the purpose of computing inequalities, we consider the coefficient of
124
+ // the nested int to be a positive integer.
125
+ //
126
+ // Thus, no modifications are needed to the logic since
127
+ // j0 >= k implies coeff * j0 >= k
128
+ //
129
+ c10::SymNode eq(const c10::SymNode& other) override;
130
+ c10::SymNode ne(const c10::SymNode& other) override;
131
+ c10::SymNode ge(const c10::SymNode& other) override;
132
+ c10::SymNode gt(const c10::SymNode& other) override;
133
+ c10::SymNode lt(const c10::SymNode& other) override;
134
+ c10::SymNode le(const c10::SymNode& other) override;
135
+ c10::SymNode mul(const c10::SymNode& other) override;
136
+
137
+ std::optional<int64_t> nested_int() override {
138
+ return val_;
139
+ }
140
+
141
+ std::optional<int64_t> nested_int_coeff() override {
142
+ return coeff_;
143
+ }
144
+
145
+ bool is_symbolic() override {
146
+ return false;
147
+ }
148
+
149
+ c10::SymNode clone() override;
150
+
151
+ #define DEFINE_BINARY_NOT_SUPPORTED(name) \
152
+ c10::SymNode name(const c10::SymNode& other) override { \
153
+ TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \
154
+ }
155
+
156
+ DEFINE_BINARY_NOT_SUPPORTED(add)
157
+ DEFINE_BINARY_NOT_SUPPORTED(sub)
158
+ DEFINE_BINARY_NOT_SUPPORTED(truediv)
159
+ DEFINE_BINARY_NOT_SUPPORTED(pow)
160
+ DEFINE_BINARY_NOT_SUPPORTED(floordiv)
161
+ DEFINE_BINARY_NOT_SUPPORTED(mod)
162
+ DEFINE_BINARY_NOT_SUPPORTED(sym_min)
163
+ DEFINE_BINARY_NOT_SUPPORTED(sym_max)
164
+ DEFINE_BINARY_NOT_SUPPORTED(sym_and)
165
+ DEFINE_BINARY_NOT_SUPPORTED(sym_or)
166
+
167
+ #undef DEFINE_BINARY_NOT_SUPPORTED
168
+
169
+ #define DEFINE_NOT_SUPPORTED(name) \
170
+ c10::SymNode name() override { \
171
+ TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \
172
+ }
173
+
174
+ DEFINE_NOT_SUPPORTED(sym_not)
175
+ DEFINE_NOT_SUPPORTED(ceil)
176
+ DEFINE_NOT_SUPPORTED(floor)
177
+ DEFINE_NOT_SUPPORTED(neg)
178
+ DEFINE_NOT_SUPPORTED(sym_float)
179
+
180
+ #undef DEFINE_NOT_SUPPORTED
181
+
182
+ private:
183
+ int64_t val_;
184
+ int64_t coeff_;
185
+ };
186
+
187
+ } // namespace c10
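A rough sketch of the semantics above (illustrative; the header itself cautions that these nodes are normally created through higher-level APIs rather than directly):

    auto j0 = c10::make_intrusive<c10::NestedIntSymNodeImpl>(/*val=*/0, /*coeff=*/1);
    // j0->str() == "j0"; with coeff = 2 the node would print as "2*j0".
    c10::SymNode one = j0->wrap_int(1);
    c10::SymNode ge_one = j0->ge(one);  // determinate: j0 is assumed to lie in [2, int64_t::max]
    // j0->lt(j0->wrap_int(3)) would be indeterminate and, per the note above, errors out.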
phivenv/Lib/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h ADDED
@@ -0,0 +1,240 @@
1
+ #pragma once
2
+
3
+ // define constants like M_PI and C keywords for MSVC
4
+ #ifdef _MSC_VER
5
+ #define _USE_MATH_DEFINES
6
+ #include <math.h>
7
+ #endif
8
+
9
+
10
+ #ifdef __CUDACC__
11
+ #include <cuda.h>
12
+ #endif
13
+
14
+ #include <array>
15
+ #include <c10/macros/Macros.h>
16
+ #include <cmath>
17
+ #include <cstdint>
18
+
19
+ namespace at {
20
+
21
+ // typedefs for holding vector data
22
+ namespace detail {
23
+
24
+ typedef std::array<uint32_t, 4> UINT4;
25
+ typedef std::array<uint32_t, 2> UINT2;
26
+ typedef std::array<double, 2> DOUBLE2;
27
+ typedef std::array<float, 2> FLOAT2;
28
+
29
+ } // namespace detail
30
+
31
+ /**
32
+ * Note [Philox Engine implementation]
33
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
34
+ * Originally implemented in PyTorch's fusion compiler
35
+ * Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
36
+ * for details regarding the engine.
37
+ *
38
+ * Note that currently this implementation of the philox engine is not used
39
+ * anywhere except for tests in cpu_generator_test.cpp. However, this engine
40
+ * will replace curandStatePhilox4_32_10_t in the future.
41
+ *
42
+ * The philox engine takes a seed value, a subsequence
43
+ * for starting the generation and an offset for the subsequence.
44
+ * Think of this engine as an algorithm producing a huge array. We are
45
+ * parallelizing this array by partitioning the huge array and assigning
46
+ * a thread index to each partition. In other words, each seed value
47
+ * (there are 2^64 possible seed values) gives a sub array of size
48
+ * 2^128 (each element in that array is a 128 bit number). Reasoning
49
+ * behind the array being of size 2^128 is, there are 2^64 possible
50
+ * thread index value and there is an array of size 2^64 for each of
51
+ * those thread index. Hence 2^64 * 2^64 = 2^128 for each seed value.
52
+ *
53
+ * In short, this generator can produce 2^64 (seed values) * 2^128 (number
54
+ * of elements in an array given by a seed value) = 2^192 values.
55
+ *
56
+ * Arguments:
57
+ * seed: Seed values could be any number from 0 to 2^64-1.
58
+ * subsequence: Subsequence is just the cuda thread indexing with:
59
+ * - blockIdx.x * blockDim.x + threadIdx.x
60
+ * offset: The offset variable in PhiloxEngine decides how many 128-bit
61
+ * random numbers to skip (i.e. how many groups of 4, 32-bit numbers to skip)
62
+ * and hence really decides the total number of randoms that can be achieved
63
+ * for the given subsequence.
64
+ */
65
+
66
+ class philox_engine {
67
+ public:
68
+
69
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
70
+ C10_HOST_DEVICE inline explicit philox_engine(uint64_t seed = 67280421310721,
71
+ uint64_t subsequence = 0,
72
+ uint64_t offset = 0) {
73
+
74
+ reset_state(seed, subsequence);
75
+ incr_n(offset);
76
+ }
77
+
78
+ C10_HOST_DEVICE inline void reset_state(uint64_t seed = 67280421310721,
79
+ uint64_t subsequence = 0) {
80
+ key_[0] = static_cast<uint32_t>(seed);
81
+ key_[1] = static_cast<uint32_t>(seed >> 32);
82
+ counter_ = detail::UINT4{};
83
+ counter_[2] = static_cast<uint32_t>(subsequence);
84
+ counter_[3] = static_cast<uint32_t>(subsequence >> 32);
85
+ STATE = 0;
86
+ }
87
+
88
+ /**
89
+ * Set the offset field of Philox Generator to the desired offset.
90
+ */
91
+ C10_HOST_DEVICE inline void set_offset(uint64_t offset) {
92
+ counter_[0] = static_cast<uint32_t>(offset);
93
+ counter_[1] = static_cast<uint32_t>(offset >> 32);
94
+ }
95
+
96
+ /**
97
+ * Gets the current offset of the Philox Generator.
98
+ */
99
+ C10_HOST_DEVICE uint64_t get_offset() const {
100
+ uint64_t lo = static_cast<uint64_t>(counter_[0]);
101
+ uint64_t hi = static_cast<uint64_t>(counter_[1]) << 32;
102
+ return lo | hi;
103
+ }
104
+
105
+ /**
106
+ * Produces a unique 32-bit pseudo random number on every invocation. Bookkeeps state to avoid waste.
107
+ */
108
+ C10_HOST_DEVICE inline uint32_t operator()(int32_t n_rounds = 10) { // 10 here to preserve back-compat behavior
109
+ if(STATE == 0) {
110
+ detail::UINT4 counter = counter_;
111
+ detail::UINT2 key = key_;
112
+ output_ = rand(counter, key, n_rounds);
113
+ incr();
114
+ }
115
+ uint32_t ret = output_[static_cast<int>(STATE)];
116
+ STATE = (STATE + 1) & 3;
117
+ return ret;
118
+ }
119
+
120
+ inline float randn(uint32_t n_rounds) {
121
+ #ifdef __CUDA_ARCH__
122
+ AT_ASSERT(false, "Unsupported invocation of randn on CUDA");
123
+ #endif
124
+ if(STATE == 0) {
125
+ detail::UINT4 counter = counter_;
126
+ detail::UINT2 key = key_;
127
+ output_ = rand(counter, key, n_rounds);
128
+ incr();
129
+ }
130
+ // TODO(min-jean-cho) change to Polar method, a more efficient version of Box-Muller method
131
+ // TODO(voz) We use std:: below, and thus need a separate impl for CUDA.
132
+ float u1 = 1 - uint32_to_uniform_float(output_[0]); // uint32_to_uniform_float returns [0,1), we need (0,1] to avoid passing 0 to log.
133
+ float u2 = 1 - uint32_to_uniform_float(output_[1]);
134
+ return static_cast<float>(std::sqrt(-2.0 * std::log(u1)) * std::cos(2.0 * M_PI * u2));
135
+ }
136
+
137
+ /**
138
+ * Function that Skips N 128 bit numbers in a subsequence
139
+ */
140
+ C10_HOST_DEVICE inline void incr_n(uint64_t n) {
141
+ uint32_t nlo = static_cast<uint32_t>(n);
142
+ uint32_t nhi = static_cast<uint32_t>(n >> 32);
143
+ counter_[0] += nlo;
144
+ // if overflow in x has occurred, carry over to nhi
145
+ if (counter_[0] < nlo) {
146
+ nhi++;
147
+ // if overflow in nhi has occurred during carry over,
148
+ // propagate that overflow to y and exit to increment z
149
+ // otherwise return
150
+ counter_[1] += nhi;
151
+ if(nhi != 0) {
152
+ if (nhi <= counter_[1]) {
153
+ return;
154
+ }
155
+ }
156
+ } else {
157
+ // if overflow in y has occurred during addition,
158
+ // exit to increment z
159
+ // otherwise return
160
+ counter_[1] += nhi;
161
+ if (nhi <= counter_[1]) {
162
+ return;
163
+ }
164
+ }
165
+ if (++counter_[2])
166
+ return;
167
+ ++counter_[3];
168
+ }
169
+
170
+ /**
171
+ * Function that Skips one 128 bit number in a subsequence
172
+ */
173
+ C10_HOST_DEVICE inline void incr() {
174
+ if (++counter_[0])
175
+ return;
176
+ if (++counter_[1])
177
+ return;
178
+ if (++counter_[2]) {
179
+ return;
180
+ }
181
+ ++counter_[3];
182
+ }
183
+
184
+ private:
185
+ detail::UINT4 counter_;
186
+ detail::UINT4 output_;
187
+ detail::UINT2 key_;
188
+ uint32_t STATE;
189
+
190
+ C10_HOST_DEVICE inline uint32_t mulhilo32(uint32_t a, uint32_t b,
191
+ uint32_t *result_high) {
192
+ #ifdef __CUDA_ARCH__
193
+ *result_high = __umulhi(a, b);
194
+ return a*b;
195
+ #else
196
+ const uint64_t product = static_cast<uint64_t>(a) * b;
197
+ *result_high = static_cast<uint32_t>(product >> 32);
198
+ return static_cast<uint32_t>(product);
199
+ #endif
200
+ }
201
+
202
+ C10_HOST_DEVICE inline detail::UINT4 single_round(detail::UINT4 ctr, detail::UINT2 in_key) {
203
+ uint32_t hi0 = 0;
204
+ uint32_t hi1 = 0;
205
+ uint32_t lo0 = mulhilo32(kPhiloxSA, ctr[0], &hi0);
206
+ uint32_t lo1 = mulhilo32(kPhiloxSB, ctr[2], &hi1);
207
+ detail::UINT4 ret;
208
+ ret[0] = hi1 ^ ctr[1] ^ in_key[0];
209
+ ret[1] = lo1;
210
+ ret[2] = hi0 ^ ctr[3] ^ in_key[1];
211
+ ret[3] = lo0;
212
+ return ret;
213
+ }
214
+
215
+ C10_HOST_DEVICE constexpr float uint32_to_uniform_float(uint32_t value) {
216
+ // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
217
+ constexpr float scale = 4.6566127342e-10;
218
+ return static_cast<float>(value & 0x7FFFFFFF) * scale;
219
+ }
220
+
221
+
222
+
223
+ C10_HOST_DEVICE inline detail::UINT4 rand(detail::UINT4& counter, detail::UINT2& key, uint32_t n_rounds) {
224
+ for (uint32_t round = 0; round < (n_rounds - 1); round++) {
225
+ counter = single_round(counter, key);
226
+ key[0] += (kPhilox10A); key[1] += (kPhilox10B);
227
+ }
228
+ return single_round(counter, key);
229
+ }
230
+
231
+
232
+ static const uint32_t kPhilox10A = 0x9E3779B9;
233
+ static const uint32_t kPhilox10B = 0xBB67AE85;
234
+ static const uint32_t kPhiloxSA = 0xD2511F53;
235
+ static const uint32_t kPhiloxSB = 0xCD9E8D57;
236
+ };
237
+
238
+ typedef philox_engine Philox4_32;
239
+
240
+ } // namespace at
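
For orientation, a minimal host-side sketch of how this engine is typically driven; the seed and offsets below are arbitrary illustrations, not values PyTorch itself uses:

    #include <ATen/core/PhiloxRNGEngine.h>
    #include <cstdint>

    void philox_demo() {
      // One engine per (seed, subsequence); offset skips whole 4x32-bit blocks.
      at::Philox4_32 engine(/*seed=*/42, /*subsequence=*/0, /*offset=*/0);
      uint32_t a = engine(); // first draw, computed from the first 128-bit block
      uint32_t b = engine(); // served from the same block via STATE bookkeeping

      // An engine constructed with offset=1 starts one 128-bit block later,
      // so its first draw equals the fifth draw of `engine`.
      at::Philox4_32 skipped(/*seed=*/42, /*subsequence=*/0, /*offset=*/1);
      uint32_t c = skipped();
      (void)a; (void)b; (void)c;
    }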
phivenv/Lib/site-packages/torch/include/ATen/core/PythonFallbackKernel.h ADDED
@@ -0,0 +1,35 @@
+ #pragma once
+ #include <ATen/core/TorchDispatchUtils.h>
+
+
+ namespace at::impl {
+
+ struct TORCH_API RestorePythonTLSSnapshot {
+   RestorePythonTLSSnapshot();
+   RestorePythonTLSSnapshot(RestorePythonTLSSnapshot&& other) = delete;
+   RestorePythonTLSSnapshot(const RestorePythonTLSSnapshot&) = delete;
+   RestorePythonTLSSnapshot& operator=(const RestorePythonTLSSnapshot&) = delete;
+   RestorePythonTLSSnapshot& operator=(RestorePythonTLSSnapshot&&) = delete;
+   ~RestorePythonTLSSnapshot();
+
+  private:
+   c10::impl::LocalDispatchKeySet saved_;
+   c10::impl::ForceDispatchKeyGuard guard_;
+ };
+
+
+ // RAII guard to make working with the above TLS safer.
+ struct TORCH_API MaybeSetTLSOnEntryGuard {
+  public:
+   MaybeSetTLSOnEntryGuard();
+   MaybeSetTLSOnEntryGuard(MaybeSetTLSOnEntryGuard&& other) = delete;
+   MaybeSetTLSOnEntryGuard(const MaybeSetTLSOnEntryGuard&) = delete;
+   MaybeSetTLSOnEntryGuard& operator=(const MaybeSetTLSOnEntryGuard&) = delete;
+   MaybeSetTLSOnEntryGuard& operator=(MaybeSetTLSOnEntryGuard&&) = delete;
+   ~MaybeSetTLSOnEntryGuard();
+
+  private:
+   bool value_set_;
+ };
+
+ } // namespace at::impl
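
These guards are internal plumbing, but the intended scoping is plain RAII; a hedged sketch (the function name is hypothetical):

    void call_python_fallback_hypothetical() {
      // Set the TLS on entry if it isn't already set.
      at::impl::MaybeSetTLSOnEntryGuard guard;
      // ... dispatch into the Python fallback kernel here ...
    } // the guard's destructor undoes whatever it set, on scope exit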
phivenv/Lib/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h ADDED
@@ -0,0 +1,22 @@
+ #pragma once
+
+ #include <ATen/core/dispatch/Dispatcher.h>
+
+ // TODO: this can probably live in c10
+
+
+ namespace at::impl {
+
+ class TORCH_API PythonOpRegistrationTrampoline final {
+   static std::atomic<c10::impl::PyInterpreter*> interpreter_;
+
+  public:
+   // Returns true if you successfully registered yourself (that means
+   // you are in the hot seat for doing the operator registrations!)
+   static bool registerInterpreter(c10::impl::PyInterpreter*);
+
+   // Returns nullptr if no interpreter has been registered yet.
+   static c10::impl::PyInterpreter* getInterpreter();
+ };
+
+ } // namespace at::impl
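
The registration handshake is first-wins; a hedged sketch of the caller's side (the function name is hypothetical):

    void maybe_register_ops_hypothetical(c10::impl::PyInterpreter* interp) {
      // Only the first interpreter to register wins the race and should
      // go on to perform the actual operator registrations.
      if (at::impl::PythonOpRegistrationTrampoline::registerInterpreter(interp)) {
        // ... perform operator registrations here ...
      }
    }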
phivenv/Lib/site-packages/torch/include/ATen/core/QuantizerBase.h ADDED
@@ -0,0 +1,84 @@
+ #pragma once
+
+ #include <c10/core/ScalarType.h>
+ #include <c10/core/QScheme.h>
+ #include <c10/util/intrusive_ptr.h>
+
+ namespace at {
+
+ class Tensor;
+ struct QTensorImpl;
+ struct Quantizer;
+ using ConstQuantizerPtr = const c10::intrusive_ptr<Quantizer>&;
+ using QuantizerPtr = c10::intrusive_ptr<Quantizer>;
+
+ /**
+  * Quantizer is the class for storing all the information
+  * that's necessary to perform quantize and dequantize
+  * operations.
+  *
+  * We might have different types of quantization schemes and this is
+  * the base class for all quantizers.
+  *
+  * QTensorImpl will hold a pointer to Quantizer so that we can support
+  * different quantization schemes on Tensor.
+  *
+  * For example, the most common quantization scheme, Affine Quantization,
+  * requires scale and zero_point as parameters; we store scale and zero_point
+  * inside the instance and use them to quantize a float Tensor or
+  * dequantize a quantized Tensor.
+  *
+  * When you add a new type of leaf Quantizer class, please also
+  * make sure to add a corresponding QScheme enum, since
+  * they should have a one-to-one mapping.
+  *
+  * Note about intrusive_ptr:
+  * A quantized Tensor holds an intrusive_ptr to its Quantizer, and multiple
+  * Tensors can share the same Quantizer. Quantizer should be immutable.
+  */
+ struct TORCH_API Quantizer : public c10::intrusive_ptr_target {
+   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+   const ScalarType scalar_type_;
+   explicit Quantizer(ScalarType scalar_type) : scalar_type_(scalar_type) {}
+   ~Quantizer() override = default;
+
+   // Copied from torch/csrc/jit/ir/scope.h
+   QuantizerPtr intrusive_from_this() {
+     c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
+                                            // from a raw `this` pointer
+                                            // so we need to bump the refcount
+                                            // to account for this ownership
+     return c10::intrusive_ptr<Quantizer>::reclaim(this);
+   }
+
+   /**
+    * Each concrete Quantizer type should have a unique QScheme type.
+    */
+   virtual QScheme qscheme() const = 0;
+
+   ScalarType scalar_type() const {
+     return scalar_type_;
+   }
+
+   /**
+    * Quantize a float Tensor into a quantized Tensor.
+    */
+   virtual Tensor quantize(const Tensor& t) = 0;
+
+   /**
+    * Dequantize a quantized Tensor into a float Tensor.
+    */
+   virtual Tensor dequantize(const Tensor& t) = 0;
+
+   /**
+    * Dequantize a quantized Tensor into a float Tensor, out= variant.
+    */
+   virtual Tensor& dequantize_out(Tensor& out, const Tensor& t) = 0;
+
+   /**
+    * Compare against `other` for equality.
+    */
+   virtual bool equalTo(QuantizerPtr other) const = 0;
+ };
+
+ } // namespace at
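
The affine scheme referenced above maps a real value x to an integer q via q = round(x / scale) + zero_point, and back via x ≈ (q - zero_point) * scale. A standalone sketch of that arithmetic (illustrative only, not PyTorch's actual kernels):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Quantize a float to int8 with the given affine parameters.
    int8_t affine_quantize(float x, float scale, int32_t zero_point) {
      int32_t q = static_cast<int32_t>(std::nearbyint(x / scale)) + zero_point;
      return static_cast<int8_t>(std::clamp(q, -128, 127)); // saturate to the int8 range
    }

    // Recover an approximation of the original float.
    float affine_dequantize(int8_t q, float scale, int32_t zero_point) {
      return static_cast<float>(static_cast<int32_t>(q) - zero_point) * scale;
    }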
phivenv/Lib/site-packages/torch/include/ATen/core/Range.h ADDED
@@ -0,0 +1,25 @@
+ #pragma once
+
+ #include <cstdint>
+ #include <iosfwd>
+
+ namespace at {
+
+ struct Range {
+   Range(int64_t begin, int64_t end)
+     : begin(begin)
+     , end(end) {}
+
+   int64_t size() const { return end - begin; }
+
+   Range operator/(int64_t divisor) {
+     return Range(begin / divisor, end / divisor);
+   }
+
+   int64_t begin;
+   int64_t end;
+ };
+
+ std::ostream& operator<<(std::ostream& out, const Range& range);
+
+ } // namespace at
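
A small usage sketch (values arbitrary); note that operator/ divides both endpoints:

    #include <ATen/core/Range.h>

    void range_demo() {
      at::Range chunk(0, 128);
      int64_t n = chunk.size();     // 128
      at::Range halved = chunk / 2; // Range(0, 64)
      (void)n; (void)halved;
    }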
phivenv/Lib/site-packages/torch/include/ATen/core/Reduction.h ADDED
@@ -0,0 +1,14 @@
+ #pragma once
+
+ namespace at::Reduction {
+
+ // NB: Keep this in sync with Reduction class in torch/nn/_reduction.py
+ // These constants control the reduction behavior of loss functions.
+ // Ideally, this would be a scoped enum, but jit doesn't support that
+ enum Reduction {
+   None, // Do not reduce
+   Mean, // (Possibly weighted) mean of losses
+   Sum,  // Sum losses
+   END
+ };
+ } // namespace at::Reduction
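
To make the enum's contract concrete, a hedged sketch of how a loss routine might consume it (standalone, not PyTorch's implementation):

    #include <ATen/core/Reduction.h>
    #include <numeric>
    #include <vector>

    double reduce_losses(const std::vector<double>& losses,
                         at::Reduction::Reduction r) {
      double sum = std::accumulate(losses.begin(), losses.end(), 0.0);
      switch (r) {
        case at::Reduction::Sum:  return sum;
        case at::Reduction::Mean: return losses.empty() ? 0.0 : sum / losses.size();
        // Reduction::None would keep the per-element losses; this simplified
        // sketch falls back to the sum for any other value.
        default:                  return sum;
      }
    }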
phivenv/Lib/site-packages/torch/include/ATen/core/Scalar.h ADDED
@@ -0,0 +1 @@
+ #include <c10/core/Scalar.h>
phivenv/Lib/site-packages/torch/include/ATen/core/ScalarType.h ADDED
@@ -0,0 +1 @@
+ #include <c10/core/ScalarType.h>
phivenv/Lib/site-packages/torch/include/ATen/core/Tensor.h ADDED
@@ -0,0 +1,98 @@
+ #pragma once
+
+ #include <ATen/core/TensorBody.h>
+ #include <c10/util/Exception.h>
+
+ namespace at {
+ // NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
+ class TORCH_API OptionalTensorRef {
+  public:
+   OptionalTensorRef() = default;
+
+   ~OptionalTensorRef() {
+     ref_.unsafeReleaseTensorImpl();
+   }
+
+   OptionalTensorRef(const TensorBase& src)
+     : ref_(Tensor::unsafe_borrow_t{}, src) {
+     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
+   }
+
+   OptionalTensorRef(const OptionalTensorRef& rhs)
+     : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {}
+
+   OptionalTensorRef(OptionalTensorRef&& rhs) = default;
+   OptionalTensorRef& operator=(OptionalTensorRef rhs) {
+     std::swap(ref_, rhs.ref_);
+     return *this;
+   }
+
+   bool has_value() const {
+     return ref_.defined();
+   }
+
+   const Tensor& getTensorRef() const & {
+     return ref_;
+   }
+
+   const Tensor& operator*() const & {
+     return ref_;
+   }
+
+   const Tensor* operator->() const & {
+     return &ref_;
+   }
+
+   operator bool() const {
+     return ref_.defined();
+   }
+
+  private:
+   Tensor ref_;
+ };
+
+ // Use to convert a TensorBase (that may be undefined) to an at::Tensor
+ // without bumping refcount.
+ class TORCH_API TensorRef {
+  public:
+   ~TensorRef() {
+     ref_.unsafeReleaseTensorImpl();
+   }
+
+   TensorRef(const TensorBase& src)
+     : ref_(Tensor::unsafe_borrow_t{}, src) {}
+   TensorRef(TensorRef&& other) = default;
+   TensorRef(const TensorRef&) = default;
+   TensorRef& operator=(const TensorRef&) = default;
+   TensorRef& operator=(TensorRef&&) = default;
+
+   const Tensor& operator*() const & {
+     return ref_;
+   }
+  private:
+   Tensor ref_;
+ };
+
+ template <typename T>
+ auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> {
+   // For hooks with a void return type, return an empty Tensor so the wrapped
+   // std::function still has a Tensor return type.
+   static_assert(std::is_same_v<decltype(hook(Tensor())), void>,
+                 "Expected hook to return void");
+   return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
+     TensorRef grad(grad_base);
+     fn(*grad);
+     return Tensor();
+   });
+ }
+
+ template <typename T>
+ auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> {
+   return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
+     TensorRef grad(grad_base);
+     Tensor ret = fn(*grad);
+     return TensorBase(std::move(ret));
+   });
+ }
+
+ } // namespace at
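
A hedged sketch of both hook flavors from the caller's side (assumes a tensor that participates in autograd; names illustrative):

    #include <ATen/core/Tensor.h>

    void hook_demo(const at::Tensor& t) {
      // Void-returning hook: observe the gradient without replacing it.
      t.register_hook([](const at::Tensor& grad) {
        // e.g. inspect or log grad here
      });

      // Tensor-returning hook: the returned tensor replaces the gradient.
      t.register_hook([](at::Tensor grad) { return grad * 2; });
    }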
phivenv/Lib/site-packages/torch/include/ATen/core/TensorAccessor.h ADDED
@@ -0,0 +1,275 @@
+ #pragma once
+
+ #include <c10/macros/Macros.h>
+ #include <c10/util/ArrayRef.h>
+ #include <c10/util/Deprecated.h>
+ #include <c10/util/Exception.h>
+ #include <c10/util/irange.h>
+ #include <cstddef>
+ #include <cstdint>
+ #include <type_traits>
+
+ namespace at {
+
+ // The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
+ // is used to enable the __restrict__ keyword/modifier for the data
+ // passed to cuda.
+ template <typename T>
+ struct DefaultPtrTraits {
+   typedef T* PtrType;
+ };
+
+ #if defined(__CUDACC__) || defined(__HIPCC__)
+ template <typename T>
+ struct RestrictPtrTraits {
+   typedef T* __restrict__ PtrType;
+ };
+ #endif
+
+ // TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.
+ // For CUDA tensors they are used in device code (only). This means that we restrict ourselves
+ // to functions and types available there (e.g. IntArrayRef isn't).
+
+ // The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+ class TensorAccessorBase {
+ public:
+   typedef typename PtrTraits<T>::PtrType PtrType;
+
+   C10_HOST_DEVICE TensorAccessorBase(
+       PtrType data_,
+       const index_t* sizes_,
+       const index_t* strides_)
+     : data_(data_), sizes_(sizes_), strides_(strides_) {}
+   C10_HOST IntArrayRef sizes() const {
+     return IntArrayRef(sizes_, N);
+   }
+   C10_HOST IntArrayRef strides() const {
+     return IntArrayRef(strides_, N);
+   }
+   C10_HOST_DEVICE index_t stride(index_t i) const {
+     return strides_[i];
+   }
+   C10_HOST_DEVICE index_t size(index_t i) const {
+     return sizes_[i];
+   }
+   C10_HOST_DEVICE PtrType data() {
+     return data_;
+   }
+   C10_HOST_DEVICE const PtrType data() const {
+     return data_;
+   }
+ protected:
+   PtrType data_;
+   const index_t* sizes_;
+   const index_t* strides_;
+ };
+
+ // The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
+ // `Tensor.accessor<T, N>()`.
+ // For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only
+ // indexing on the device uses `TensorAccessor`s.
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+ class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {
+ public:
+   typedef typename PtrTraits<T>::PtrType PtrType;
+
+   C10_HOST_DEVICE TensorAccessor(
+       PtrType data_,
+       const index_t* sizes_,
+       const index_t* strides_)
+     : TensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
+
+   C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
+     return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, this->sizes_+1, this->strides_+1);
+   }
+
+   C10_HOST_DEVICE const TensorAccessor<T, N-1, PtrTraits, index_t> operator[](index_t i) const {
+     return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, this->sizes_+1, this->strides_+1);
+   }
+ };
+
+ template<typename T, template <typename U> class PtrTraits, typename index_t>
+ class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrTraits,index_t> {
+ public:
+   typedef typename PtrTraits<T>::PtrType PtrType;
+
+   C10_HOST_DEVICE TensorAccessor(
+       PtrType data_,
+       const index_t* sizes_,
+       const index_t* strides_)
+     : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
+   C10_HOST_DEVICE T & operator[](index_t i) {
+     // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
+     return this->data_[this->strides_[0]*i];
+   }
+   C10_HOST_DEVICE const T & operator[](index_t i) const {
+     return this->data_[this->strides_[0]*i];
+   }
+ };
+
+
+ // GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used for
+ // CUDA `Tensor`s on the host.
+ // In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)
+ // in order to transfer them to the device when calling kernels.
+ // On the device, indexing a multidimensional tensor yields `TensorAccessor`s of lower dimension.
+ // Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.
+ // Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
+ // on the device, so those functions are host only.
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+ class GenericPackedTensorAccessorBase {
+ public:
+   typedef typename PtrTraits<T>::PtrType PtrType;
+   C10_HOST GenericPackedTensorAccessorBase(
+       PtrType data_,
+       const index_t* sizes_,
+       const index_t* strides_)
+     : data_(data_) {
+     std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
+     std::copy(strides_, strides_ + N, std::begin(this->strides_));
+   }
+
+   // if index_t is not int64_t, we want to have an int64_t constructor
+   template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
+   C10_HOST GenericPackedTensorAccessorBase(
+       PtrType data_,
+       const source_index_t* sizes_,
+       const source_index_t* strides_)
+     : data_(data_) {
+     for (const auto i : c10::irange(N)) {
+       this->sizes_[i] = sizes_[i];
+       this->strides_[i] = strides_[i];
+     }
+   }
+
+   C10_HOST_DEVICE index_t stride(index_t i) const {
+     return strides_[i];
+   }
+   C10_HOST_DEVICE index_t size(index_t i) const {
+     return sizes_[i];
+   }
+   C10_HOST_DEVICE PtrType data() {
+     return data_;
+   }
+   C10_HOST_DEVICE const PtrType data() const {
+     return data_;
+   }
+ protected:
+   PtrType data_;
+   // NOLINTNEXTLINE(*c-arrays*)
+   index_t sizes_[N];
+   // NOLINTNEXTLINE(*c-arrays*)
+   index_t strides_[N];
+   C10_HOST void bounds_check_(index_t i) const {
+     TORCH_CHECK_INDEX(
+         0 <= i && i < index_t{N},
+         "Index ",
+         i,
+         " is not within bounds of a tensor of dimension ",
+         N);
+   }
+ };
+
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+ class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T,N,PtrTraits,index_t> {
+ public:
+   typedef typename PtrTraits<T>::PtrType PtrType;
+
+   C10_HOST GenericPackedTensorAccessor(
+       PtrType data_,
+       const index_t* sizes_,
+       const index_t* strides_)
+     : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
+
+   // if index_t is not int64_t, we want to have an int64_t constructor
+   template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
+   C10_HOST GenericPackedTensorAccessor(
+       PtrType data_,
+       const source_index_t* sizes_,
+       const source_index_t* strides_)
+     : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
+
+   C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
+     index_t* new_sizes = this->sizes_ + 1;
+     index_t* new_strides = this->strides_ + 1;
+     return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
+   }
+
+   C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
+     const index_t* new_sizes = this->sizes_ + 1;
+     const index_t* new_strides = this->strides_ + 1;
+     return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
+   }
+
+   /// Returns a PackedTensorAccessor of the same dimension after transposing the
+   /// two dimensions given. Does not actually move elements; transposition is
+   /// made by permuting the size/stride arrays. If the dimensions are not valid,
+   /// asserts.
+   C10_HOST GenericPackedTensorAccessor<T, N, PtrTraits, index_t> transpose(
+       index_t dim1,
+       index_t dim2) const {
+     this->bounds_check_(dim1);
+     this->bounds_check_(dim2);
+     GenericPackedTensorAccessor<T, N, PtrTraits, index_t> result(
+         this->data_, this->sizes_, this->strides_);
+     std::swap(result.strides_[dim1], result.strides_[dim2]);
+     std::swap(result.sizes_[dim1], result.sizes_[dim2]);
+     return result;
+   }
+ };
+
+ template<typename T, template <typename U> class PtrTraits, typename index_t>
+ class GenericPackedTensorAccessor<T,1,PtrTraits,index_t> : public GenericPackedTensorAccessorBase<T,1,PtrTraits,index_t> {
+ public:
+   typedef typename PtrTraits<T>::PtrType PtrType;
+   C10_HOST GenericPackedTensorAccessor(
+       PtrType data_,
+       const index_t* sizes_,
+       const index_t* strides_)
+     : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
+
+   // if index_t is not int64_t, we want to have an int64_t constructor
+   template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
+   C10_HOST GenericPackedTensorAccessor(
+       PtrType data_,
+       const source_index_t* sizes_,
+       const source_index_t* strides_)
+     : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
+
+   C10_DEVICE T & operator[](index_t i) {
+     return this->data_[this->strides_[0] * i];
+   }
+   C10_DEVICE const T& operator[](index_t i) const {
+     return this->data_[this->strides_[0]*i];
+   }
+
+   // Same as in the general N-dimensional case, but note that in the
+   // 1-dimensional case the returned PackedTensorAccessor will always be an
+   // identical copy of the original
+   C10_HOST GenericPackedTensorAccessor<T, 1, PtrTraits, index_t> transpose(
+       index_t dim1,
+       index_t dim2) const {
+     this->bounds_check_(dim1);
+     this->bounds_check_(dim2);
+     return GenericPackedTensorAccessor<T, 1, PtrTraits, index_t>(
+         this->data_, this->sizes_, this->strides_);
+   }
+ };
+
+
+ // Can't put this directly into the macro function args because of commas
+ #define AT_X GenericPackedTensorAccessor<T, N, PtrTraits, index_t>
+
+ // Old name for `GenericPackedTensorAccessor`
+ template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+ C10_DEFINE_DEPRECATED_USING(PackedTensorAccessor, AT_X)
+
+ #undef AT_X
+
+ template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
+ using PackedTensorAccessor32 = GenericPackedTensorAccessor<T, N, PtrTraits, int32_t>;
+
+ template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
+ using PackedTensorAccessor64 = GenericPackedTensorAccessor<T, N, PtrTraits, int64_t>;
+ } // namespace at
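
A hedged sketch of the CPU accessor path described above (assumes a 2-D float CPU tensor; kernel details elided):

    #include <ATen/core/Tensor.h>

    void accessor_demo(at::Tensor& t) {
      auto acc = t.accessor<float, 2>();  // checks dim() == 2 at runtime
      for (int64_t i = 0; i < acc.size(0); ++i) {
        for (int64_t j = 0; j < acc.size(1); ++j) {
          acc[i][j] += 1.0f;  // strided element access, no dispatch overhead
        }
      }
    }

    // For CUDA, one would instead build a packed accessor on the host and pass
    // it to a kernel by value, e.g. (my_kernel is hypothetical):
    //   auto packed = t.packed_accessor32<float, 2, at::RestrictPtrTraits>();
    //   my_kernel<<<blocks, threads>>>(packed);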
phivenv/Lib/site-packages/torch/include/ATen/core/TensorBase.h ADDED
@@ -0,0 +1,1056 @@
+ #pragma once
+
+ #include <c10/core/Device.h>
+ #include <c10/core/Layout.h>
+ #include <c10/core/MemoryFormat.h>
+ #include <c10/core/ScalarType.h>
+ #include <c10/core/ScalarTypeToTypeMeta.h>
+ #include <c10/core/Storage.h>
+ #include <c10/core/SymIntArrayRef.h>
+ #include <c10/core/TensorImpl.h>
+ #include <c10/core/TensorOptions.h>
+ #include <c10/core/UndefinedTensorImpl.h>
+ #include <c10/core/WrapDimMinimal.h>
+ #include <c10/util/C++17.h>
+ #include <c10/util/Exception.h>
+ #include <c10/util/ExclusivelyOwned.h>
+ #include <c10/util/ExclusivelyOwnedTensorTraits.h>
+ #include <c10/util/MaybeOwned.h>
+ #include <optional>
+ #include <c10/util/intrusive_ptr.h>
+
+ #include <ATen/core/NamedTensor.h>
+ #include <ATen/core/QuantizerBase.h>
+ #include <ATen/core/TensorAccessor.h>
+ #include <ATen/StorageUtils.h>
+
+ namespace c10 {
+ class Scalar;
+ }
+
+ namespace torch::autograd {
+
+ struct Node;
+
+ } // namespace torch::autograd
+
+ namespace at {
+
+ class Tensor;
+ class TensorBase;
+
+ // Convert Tensor to TensorBase without any need to include Tensor.h
+ TORCH_API const TensorBase& get_tensor_base(const Tensor& t);
+
+ namespace impl {
+ inline bool variable_excluded_from_dispatch() {
+ #ifdef C10_MOBILE
+   // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
+   return true;
+ #else
+   return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
+ #endif
+ }
+
+ }
+
+ // NOTE: [Tensor vs. TensorBase]
+ //
+ // Tensor, being the central data structure in PyTorch, gets used and
+ // its header included almost everywhere. Unfortunately this means
+ // every time an operator signature is updated or changed in
+ // native_functions.yaml, you (and every other PyTorch developer) need
+ // to recompile all of ATen and its dependencies.
+ //
+ // TensorBase aims to break up these header dependencies, and improve
+ // incremental build times for all PyTorch developers. TensorBase
+ // represents a reference counted handle to TensorImpl, exactly the
+ // same as Tensor. However, TensorBase doesn't have code generated
+ // methods in its API and thus no dependence on native_functions.yaml.
+ //
+ // Usage tips
+ // ----------
+ // - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp
+ //   or .cu file to ensure it has no header dependencies on
+ //   native_functions.yaml (direct or indirect).
+ // - Tensor inherits from TensorBase, so functions taking
+ //   `const TensorBase &` are callable with Tensor as well.
+ // - TensorBase can be converted to Tensor with `Tensor(tensor_base)`,
+ //   but this requires a reference-count bump. OptionalTensorRef, on
+ //   the other hand, can materialize a `const Tensor &` without
+ //   touching the reference-count.
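
A hedged sketch illustrating these usage tips (the helper below is illustrative, not part of ATen):

    #define TORCH_ASSERT_NO_OPERATORS  // keep this TU free of generated-op headers
    #include <ATen/core/TensorBase.h>

    // Callable with at::Tensor too, since Tensor inherits from TensorBase.
    int64_t count_elements(const at::TensorBase& t) {
      return t.defined() ? t.numel() : 0;
    }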
82
+ class TORCH_API TensorBase {
83
+ public:
84
+ struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };
85
+
86
+ protected:
87
+ // Create a Tensor with a +0 reference count. Special care must be
88
+ // taken to avoid decrementing this reference count at destruction
89
+ // time. Intended to support MaybeOwnedTraits<Tensor>.
90
+ explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs)
91
+ : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>(rhs.impl_.get(), c10::raw::DontIncreaseRefcount{})) {}
92
+ friend MaybeOwnedTraits<TensorBase>;
93
+
94
+ public:
95
+ TensorBase() = default;
96
+ // This constructor should not be used by end users and is an implementation
97
+ // detail invoked by autogenerated code.
98
+ explicit TensorBase(
99
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
100
+ : impl_(std::move(tensor_impl)) {
101
+ if (impl_.get() == nullptr) {
102
+ throw std::runtime_error("TensorImpl with nullptr is not supported");
103
+ }
104
+ }
105
+ TensorBase(const TensorBase&) = default;
106
+ TensorBase(TensorBase&&) noexcept = default;
107
+ ~TensorBase() noexcept = default;
108
+
109
+ public:
110
+ // Creates a new wrapper from TensorImpl. Intentionally a free method because
111
+ // it should be used with care. Checks necessary invariants
112
+ static TensorBase wrap_tensor_impl(
113
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
114
+ TensorBase r(std::move(tensor_impl));
115
+ r.enforce_invariants();
116
+ return r;
117
+ }
118
+
119
+ int64_t dim() const {
120
+ return impl_->dim();
121
+ }
122
+ int64_t storage_offset() const {
123
+ return impl_->storage_offset();
124
+ }
125
+
126
+ TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
127
+ if (is_contiguous(memory_format)) {
128
+ return *this;
129
+ } else {
130
+ return __dispatch_contiguous(memory_format);
131
+ }
132
+ }
133
+
134
+ /// Should be used if *this can reasonably be expected to be contiguous and
135
+ /// performance is important.
136
+ /// Compared to contiguous, it saves a reference count
137
+ /// increment/decrement if *this is already contiguous, at the cost
138
+ /// in all cases of an extra pointer of stack usage, an extra branch
139
+ /// to access, and an extra branch at destruction time.
140
+ c10::MaybeOwned<TensorBase> expect_contiguous(
141
+ MemoryFormat memory_format=MemoryFormat::Contiguous) const &;
142
+
143
+ // Use .contiguous() instead. Trying to borrow from a prvalue
144
+ // will only lead to trouble and dangling references.
145
+ c10::MaybeOwned<TensorBase> expect_contiguous(
146
+ MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
147
+
148
+ const TensorBase& fill_(const c10::Scalar& scalar) const;
149
+ const TensorBase& zero_() const;
150
+
151
+ TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, std::optional<at::MemoryFormat> memory_format=std::nullopt) const;
152
+
153
+ bool is_complex() const {
154
+ return at::isComplexType(this->scalar_type());
155
+ }
156
+
157
+ bool is_floating_point() const {
158
+ return at::isFloatingType(this->scalar_type());
159
+ }
160
+
161
+ bool is_signed() const {
162
+ return at::isSignedType(this->scalar_type());
163
+ }
164
+
165
+ c10::SymInt sym_size(int64_t dim) const {
166
+ return impl_->sym_size(dim);
167
+ }
168
+
169
+ c10::SymInt sym_stride(int64_t dim) const {
170
+ const auto sizes = this->sym_strides();
171
+ const auto ndim = static_cast<int64_t>(sizes.size());
172
+ // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
173
+ return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
174
+
175
+ }
176
+
177
+ int64_t size(int64_t dim) const {
178
+ return impl_->size(dim);
179
+ }
180
+
181
+ int64_t stride(int64_t dim) const {
182
+ const auto strides = this->strides();
183
+ const auto ndim = static_cast<int64_t>(strides.size());
184
+ // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
185
+ return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
186
+ }
187
+
188
+ TensorImpl * unsafeGetTensorImpl() const {
189
+ return impl_.get();
190
+ }
191
+ TensorImpl * unsafeReleaseTensorImpl() {
192
+ return impl_.release();
193
+ }
194
+ const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr() const {
195
+ return impl_;
196
+ }
197
+
198
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> unsafeReleaseIntrusivePtr() {
199
+ return std::move(impl_);
200
+ }
201
+
202
+ bool defined() const {
203
+ return impl_;
204
+ }
205
+
206
+ void reset() {
207
+ impl_.reset();
208
+ }
209
+
210
+ #if defined (_MSC_VER)
211
+ TensorBase& operator=(const TensorBase& x) & {
212
+ impl_ = x.impl_;
213
+ return *this;
214
+ };
215
+ TensorBase& operator=(TensorBase&& x) & noexcept {
216
+ impl_ = std::move(x.impl_);
217
+ return *this;
218
+ }
219
+ #else
220
+ TensorBase& operator=(const TensorBase& x) & = default;
221
+ TensorBase& operator=(TensorBase&& x) & noexcept = default;
222
+ #endif
223
+
224
+ // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here
225
+ TensorBase& operator=(const TensorBase&) && = delete;
226
+ TensorBase& operator=(TensorBase&&) && noexcept = delete;
227
+
228
+ bool is_same(const TensorBase& other) const noexcept {
229
+ return impl_ == other.impl_;
230
+ }
231
+ size_t use_count() const noexcept {
232
+ return impl_.use_count();
233
+ }
234
+ size_t weak_use_count() const noexcept {
235
+ return impl_.weak_use_count();
236
+ }
237
+
238
+ std::string toString() const;
239
+
240
+ IntArrayRef sizes() const {
241
+ return impl_->sizes();
242
+ }
243
+ c10::SymIntArrayRef sym_sizes() const {
244
+ return impl_->sym_sizes();
245
+ }
246
+ c10::SymIntArrayRef sym_strides() const {
247
+ return impl_->sym_strides();
248
+ }
249
+ IntArrayRef strides() const {
250
+ return impl_->strides();
251
+ }
252
+ // See impl::get_opt_names in ATen/NamedTensor.h for docs.
253
+ std::optional<DimnameList> opt_names() const {
254
+ return impl::get_opt_names(unsafeGetTensorImpl());
255
+ }
256
+ // See impl::get_names in ATen/NamedTensor.h for docs.
257
+ DimnameList names() const {
258
+ return impl::get_names(unsafeGetTensorImpl());
259
+ }
260
+ int64_t ndimension() const {
261
+ return dim();
262
+ }
263
+
264
+ bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
265
+ return impl_->is_contiguous(memory_format);
266
+ }
267
+
268
+ bool is_non_overlapping_and_dense() const {
269
+ return impl_->is_non_overlapping_and_dense();
270
+ }
271
+
272
+ at::MemoryFormat suggest_memory_format(
273
+ bool channels_last_strides_exact_match = false) const {
274
+ // Setting channels_last_strides_exact_match to true forces function to
275
+ // check 0,1 - sized dimension strides.
276
+ if (layout() == at::kStrided) {
277
+ if (impl_->is_strides_like_channels_last()) {
278
+ if (!channels_last_strides_exact_match ||
279
+ get_channels_last_strides_2d(sizes()) == strides()) {
280
+ return at::MemoryFormat::ChannelsLast;
281
+ }
282
+ }
283
+ else if (impl_->is_strides_like_channels_last_3d()) {
284
+ if (!channels_last_strides_exact_match ||
285
+ get_channels_last_strides_3d(sizes()) == strides()) {
286
+ return at::MemoryFormat::ChannelsLast3d;
287
+ }
288
+ }
289
+ }
290
+ return at::MemoryFormat::Contiguous;
291
+ }
292
+
293
+ // Total bytes consumed by the "view" of elements of the array. Does not
294
+ // include size of metadata. The number reported here does not necessarily
295
+ // correspond to the true physical memory consumed by a tensor; instead,
296
+ // it reports the memory the tensor would take *if* it were contiguous.
297
+ // Defined to be numel() * itemsize()
298
+ size_t nbytes() const {
299
+ TORCH_CHECK(layout () != at::kSparse,
300
+ "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
301
+ "tensors, add the nbytes of the indices and values. If you want the size of the " \
302
+ "equivalent dense tensor, multiply numel() by element_size()");
303
+ return impl_->numel() * impl_->itemsize();
304
+ }
305
+
306
+ c10::SymInt sym_nbytes() const {
307
+ TORCH_CHECK(layout () != at::kSparse,
308
+ "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
309
+ "tensors, add the nbytes of the indices and values. If you want the size of the " \
310
+ "equivalent dense tensor, multiply numel() by element_size()");
311
+ return impl_->sym_numel() * impl_->itemsize();
312
+ }
313
+
314
+ int64_t numel() const {
315
+ return impl_->numel();
316
+ }
317
+
318
+ c10::SymInt sym_numel() const {
319
+ return impl_->sym_numel();
320
+ }
321
+
322
+ c10::SymInt sym_storage_offset() const {
323
+ return impl_->sym_storage_offset();
324
+ }
325
+
326
+ // Length of one array element in bytes. This is the traditional
327
+ // Numpy naming.
328
+ size_t itemsize() const {
329
+ return impl_->itemsize();
330
+ }
331
+
332
+ // Same as itemsize(). This is the PyTorch naming.
333
+ int64_t element_size() const {
334
+ return static_cast<int64_t>(impl_->itemsize());
335
+ }
336
+
337
+ DispatchKeySet key_set() const {
338
+ return impl_->key_set();
339
+ }
340
+ ScalarType scalar_type() const {
341
+ return typeMetaToScalarType(impl_->dtype());
342
+ }
343
+ bool has_storage() const {
344
+ return defined() && impl_->has_storage();
345
+ }
346
+ const Storage& storage() const {
347
+ return impl_->storage();
348
+ }
349
+ bool is_alias_of(const at::TensorBase& other) const{
350
+ return impl_->storage().is_alias_of(other.storage());
351
+ }
352
+
353
+ // Move the storage backend to shm based
354
+ // to enable memory sharing across processes.
355
+ //
356
+ // NB1: the ideal behavior of this API still requires further discussion
357
+ // but for now we are inclined to keep it consistent with existing THP behavior
358
+ // https://github.com/pytorch/pytorch/blob/4dca9bde0552afc67b5b74f4a0696fe6055709c4/torch/storage.py#L196-L212
359
+ // so we don't assert on anything here and rely on caller knowing
360
+ // what it's doing.
361
+ //
362
+ // NB2: this currently provides Linux fd based shm support only
363
+ // to simplify the storage lifetime management logic in ATen
364
+ // and similarly for now we are not adding support for file system based
365
+ // shm support like in THP due to additional GC manager support needed
366
+ // to prevent leaks.
367
+ // As such, calling this from non supported systems (e.g. Windows) would fail.
368
+ void share_memory_() {
369
+ at::share_memory_(*this);
370
+ }
371
+
372
+ inline bool _is_zerotensor() const {
373
+ return impl_->_is_zerotensor();
374
+ }
375
+
376
+ inline void _set_zero(bool zero) const {
377
+ impl_->_set_zero(zero);
378
+ }
379
+
380
+ inline bool is_conj() const {
381
+ return impl_->is_conj();
382
+ }
383
+
384
+ // sets the conjugate bit of a tensor.
385
+ // NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are sure
386
+ // that's what you want. Changing this might lead to incorrect behavior since conjugation is
387
+ // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
388
+ inline void _set_conj(bool conjugate) const {
389
+ impl_->_set_conj(conjugate);
390
+ }
391
+
392
+ inline bool is_neg() const {
393
+ return impl_->is_neg();
394
+ }
395
+
396
+ // sets the negative bit of a tensor.
397
+ // NOTE: Negative bit is supposed to be a read-only field. Only change this, if you are sure
398
+ // that's what you want. Changing this might lead to incorrect behavior since we rely on this
399
+ // bit to determine if a negation needs to be materialized.
400
+ inline void _set_neg(bool negative) const {
401
+ impl_->_set_neg(negative);
402
+ }
403
+
404
+ /// Returns a `Tensor`'s layout.
405
+ Layout layout() const {
406
+ return impl_->layout();
407
+ }
408
+
409
+ /// Returns a `Tensor`'s dtype (`TypeMeta`).
410
+ caffe2::TypeMeta dtype() const {
411
+ return impl_->dtype();
412
+ }
413
+
414
+ /// Returns a `Tensor`'s device.
415
+ inline Device device() const {
416
+ return impl_->device();
417
+ }
418
+
419
+ /// Returns a `Tensor`'s device index.
420
+ DeviceIndex get_device() const {
421
+ // NB: this is not a native function to avoid dispatching overhead.
422
+ return impl_->get_device();
423
+ }
424
+
425
+ /// Returns if a `Tensor` has CPU backend.
426
+ bool is_cpu() const {
427
+ // NB: this is not a native function to avoid dispatching overhead.
428
+ return impl_->is_cpu();
429
+ }
430
+
431
+ /// Returns if a `Tensor` has CUDA backend.
432
+ bool is_cuda() const {
433
+ // NB: this is not a native function to avoid dispatching overhead.
434
+ return impl_->is_cuda();
435
+ }
436
+
437
+ /// Returns if a `Tensor` has IPU backend.
438
+ bool is_ipu() const {
439
+ // NB: this is not a native function to avoid dispatching overhead.
440
+ return impl_->is_ipu();
441
+ }
442
+
443
+ /// Returns if a `Tensor` has XPU backend.
444
+ bool is_xpu() const {
445
+ // NB: this is not a native function to avoid dispatching overhead.
446
+ return impl_->is_xpu();
447
+ }
448
+
449
+ /// Returns if a `Tensor` has XLA backend.
450
+ bool is_xla() const {
451
+ return impl_->is_xla();
452
+ }
453
+
454
+ /// Returns if a `Tensor` has MTIA backend.
455
+ bool is_mtia() const {
456
+ return impl_->is_mtia();
457
+ }
458
+
459
+ /// Returns if a `Tensor` has HPU backend.
460
+ bool is_hpu() const {
461
+ return impl_->is_hpu();
462
+ }
463
+
464
+ /// Returns if a `Tensor` has Lazy backend.
465
+ bool is_lazy() const {
466
+ return impl_->is_lazy();
467
+ }
468
+
469
+ /// Returns if a `Tensor` has HIP backend.
470
+ bool is_hip() const {
471
+ // NB: this is not a native function to avoid dispatching overhead.
472
+ return impl_->is_hip();
473
+ }
474
+
475
+ /// Returns if a `Tensor` has VE backend.
476
+ bool is_ve() const {
477
+ // NB: this is not a native function to avoid dispatching overhead.
478
+ return impl_->is_ve();
479
+ }
480
+
481
+ /// Returns if a `Tensor` has PrivateUse1 backend.
482
+ bool is_privateuseone() const {
483
+ // NB: this is not a native function to avoid dispatching overhead.
484
+ return impl_->is_privateuseone();
485
+ }
486
+
487
+ /// Returns if a `Tensor` has sparse backend.
488
+ bool is_sparse() const {
489
+ // NB: this is not a native function to avoid dispatching overhead.
490
+ return impl_->is_sparse();
491
+ }
492
+
493
+ /// Returns is a `Tensor` has a sparse CSR backend.
494
+ bool is_sparse_csr() const {
495
+ // NB: this is not a native function to avoid dispatching overhead.
496
+ return impl_->is_sparse_csr();
497
+ }
498
+
499
+ /// Returns if a `Tensor` is mkldnn tensor.
500
+ bool is_mkldnn() const {
501
+ // NB: this is not a native function to avoid dispatching overhead.
502
+ return impl_->is_mkldnn();
503
+ }
504
+
505
+ /// Returns if a `Tensor` is mps tensor.
506
+ bool is_mps() const {
507
+ // NB: this is not a native function to avoid dispatching overhead.
508
+ return impl_->is_mps();
509
+ }
510
+
511
+ /// Returns if a `Tensor` is maia tensor.
512
+ bool is_maia() const {
513
+ // NB: this is not a native function to avoid dispatching overhead.
514
+ return impl_->is_maia();
515
+ }
516
+
517
+ /// Returns if a `Tensor` is vulkan tensor.
518
+ bool is_vulkan() const {
519
+ // NB: this is not a native function to avoid dispatching overhead.
520
+ return impl_->is_vulkan();
521
+ }
522
+
523
+ /// Returns if a `Tensor` is metal tensor.
524
+ bool is_metal() const {
525
+ // NB: this is not a native function to avoid dispatching overhead.
526
+ return impl_->is_metal();
527
+ }
528
+
529
+ /// Returns if a `Tensor` has quantized backend.
530
+ bool is_quantized() const {
531
+ // NB: this is not a native function to avoid dispatching overhead.
532
+ return impl_->is_quantized();
533
+ }
534
+
535
+ /// Returns if a `Tensor` is a meta tensor. Meta tensors can
536
+ /// also have other designations.
537
+ bool is_meta() const {
538
+ return impl_->is_meta();
539
+ }
540
+
541
+ /// Returns if a `Tensor` is an inference tensor.
542
+ bool is_inference() const {
543
+ return impl_->is_inference();
544
+ }
545
+
546
+ // Returns if a `Tensor` is a NestedTensor.
547
+ bool is_nested() const {
548
+ return impl_->is_nested();
549
+ }
550
+
551
+ /// If a tensor is a quantized tensor, returns its quantizer
552
+ /// TODO: it's not in native_functions.yaml yet as it's not exposed to python
553
+ QuantizerPtr quantizer() const;
554
+
555
+ /// Returns if a `Tensor` has any dimension names
556
+ bool has_names() const {
557
+ // If a user is using unnamed tensors, then we can short-circuit right here.
558
+ // Otherwise, impl::has_names attempts to retrieve names.
559
+ if (!impl_->has_named_tensor_meta()) {
560
+ return false;
561
+ }
562
+ return impl::has_names(unsafeGetTensorImpl());
563
+ }
564
+
565
+ /// Returns a `Tensor`'s dimension names data structure
566
+ const NamedTensorMeta* get_named_tensor_meta() const {
567
+ return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
568
+ }
569
+
570
+ NamedTensorMeta* get_named_tensor_meta() {
571
+ return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
572
+ }
573
+
574
+ /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
575
+ /// TensorOptions.h.
576
+ TensorOptions options() const {
577
+ return TensorOptions().dtype(dtype())
578
+ .device(device())
579
+ .layout(layout());
580
+ }
581
+
582
+ const void* const_data_ptr() const {
583
+ return this->unsafeGetTensorImpl()->data();
584
+ }
585
+
586
+ void* mutable_data_ptr() const {
587
+ return this->unsafeGetTensorImpl()->mutable_data();
588
+ }
589
+
590
+ // TODO(#97856) Make this return a const pointer. This currently
591
+ // returns a non-const pointer because of the large
592
+ // number of clients that we still want to audit before
593
+ // migrating to mutable_data_ptr().
594
+ void* data_ptr() const {
595
+ return mutable_data_ptr();
596
+ }
597
+
598
+ template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
599
+ const T* const_data_ptr() const;
600
+
601
+ template <typename T, std::enable_if_t<std::is_const_v<T>, int> = 0>
602
+ const std::remove_const_t<T>* const_data_ptr() const;
603
+
604
+ template <typename T>
605
+ T* mutable_data_ptr() const;
606
+
607
+ // Legacy interface during the migration to indicate that a callsite
608
+ // has not been audited for mutability.
609
+ //
610
+ // Do not add new uses of this, use const_data_ptr() if possible,
611
+ // mutable_data_ptr() otherwise.
612
+ //
613
+ // TODO(#97856) Make this return a const pointer. This is currently
614
+ // const because of the vast number of clients that
615
+ // rely on this.
616
+ template <typename T>
617
+ T* data_ptr() const;
618
+
619
+ // Purposely not defined here to avoid inlining
620
+ void print() const;
621
+
622
+ // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and
623
+ // dimension.
624
+ template<typename T, size_t N>
625
+ TensorAccessor<T,N> accessor() const& {
626
+ static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
627
+ TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
628
+ T* ptr = nullptr;
629
+ if constexpr (std::is_const_v<T>) {
630
+ ptr = const_data_ptr<T>();
631
+ } else {
632
+ ptr = mutable_data_ptr<T>();
633
+ }
634
+ return TensorAccessor<T,N>(ptr,sizes().data(),strides().data());
635
+ }
636
+ template<typename T, size_t N>
637
+ TensorAccessor<T,N> accessor() && = delete;
638
+
639
+ // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
640
+ // dimension. You can optionally specify RestrictPtrTraits as a template parameter to
641
+ // cast the data pointer to a __restrict__ pointer.
642
+ // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
643
+ // as an argument.
644
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
645
+ GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& {
646
+ static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
647
+ TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
648
+ T* ptr = nullptr;
649
+ if constexpr (std::is_const_v<T>) {
650
+ ptr = const_data_ptr<T>();
651
+ } else {
652
+ ptr = mutable_data_ptr<T>();
653
+ }
654
+ return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(ptr),sizes().data(),strides().data());
655
+ }
656
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
657
+ GenericPackedTensorAccessor<T,N> generic_packed_accessor() && = delete;
658
+
659
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
660
+ PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& {
661
+ TORCH_CHECK(
662
+ impl_->numel() <=
663
+ static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
664
+ "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
665
+ return generic_packed_accessor<T,N,PtrTraits,int32_t>();
666
+ }
667
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
668
+ PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete;
669
+
670
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
671
+ PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& {
672
+ return generic_packed_accessor<T,N,PtrTraits,int64_t>();
673
+ }
674
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
675
+ PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete;
676
+
677
+ // ~~~~~ Autograd API ~~~~~
678
+
679
+ /// \fn bool is_leaf() const;
680
+ ///
681
+ /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention.
682
+ ///
683
+ /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were
684
+ /// created by the user. This means that they are not the result of an operation and so
685
+ /// `grad_fn()` is `nullptr`.
686
+ ///
687
+ /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`.
688
+ /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`.
689
+ ///
690
+ /// Example:
691
+ /// @code
692
+ /// auto a = torch::rand(10, torch::requires_grad());
693
+ /// std::cout << a.is_leaf() << std::endl; // prints `true`
694
+ ///
695
+ /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA);
696
+ /// std::cout << b.is_leaf() << std::endl; // prints `false`
697
+ /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor
698
+ ///
699
+ /// auto c = torch::rand(10, torch::requires_grad()) + 2;
700
+ /// std::cout << c.is_leaf() << std::endl; // prints `false`
701
+ /// // c was created by the addition operation
702
+ ///
703
+ /// auto d = torch::rand(10).cuda();
704
+ /// std::cout << d.is_leaf() << std::endl; // prints `true`
705
+ /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
706
+ ///
707
+ /// auto e = torch::rand(10).cuda().requires_grad_();
708
+ /// std::cout << e.is_leaf() << std::endl; // prints `true`
709
+ /// // e requires gradients and has no operations creating it
710
+ ///
711
+ /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true));
712
+ /// std::cout << f.is_leaf() << std::endl; // prints `true`
713
+ /// // f requires grad, has no operation creating it
714
+ /// @endcode
715
+
716
+ /// \fn void backward(const Tensor & gradient={}, std::optional<bool> retain_graph=std::nullopt, bool create_graph=false, std::optional<TensorList> inputs=std::nullopt) const;
717
+ ///
718
+ /// Computes the gradient of current tensor with respect to graph leaves.
719
+ ///
720
+ /// The graph is differentiated using the chain rule. If the tensor is
721
+ /// non-scalar (i.e. its data has more than one element) and requires
722
+ /// gradient, the function additionally requires specifying ``gradient``.
723
+ /// It should be a tensor of matching type and location, that contains
724
+ /// the gradient of the differentiated function w.r.t. this Tensor.
725
+ ///
726
+ /// This function accumulates gradients in the leaves - you might need to
727
+ /// zero them before calling it.
728
+ ///
729
+ /// \param gradient Gradient w.r.t. the
730
+ /// tensor. If it is a tensor, it will be automatically converted
731
+ /// to a Tensor that does not require grad unless ``create_graph`` is True.
732
+ /// None values can be specified for scalar Tensors or ones that
733
+ /// don't require grad. If a None value would be acceptable then
734
+ /// this argument is optional.
735
+ /// \param retain_graph If ``false``, the graph used to compute
736
+ /// the grads will be freed. Note that in nearly all cases setting
737
+ /// this option to True is not needed and often can be worked around
738
+ /// in a much more efficient way. Defaults to the value of
739
+ /// ``create_graph``.
740
+ /// \param create_graph If ``true``, graph of the derivative will
741
+ /// be constructed, allowing to compute higher order derivative
742
+ /// products. Defaults to ``false``.
743
+ /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
744
+ /// ``at::Tensor::grad``. All other Tensors will be ignored. If not
745
+ /// provided, the gradient is accumulated into all the leaf Tensors
746
+ /// that were used to compute the current tensor.
747
+ /// When inputs are provided and a given input is not a leaf,
748
+ /// the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients).
749
+ /// It is an implementation detail on which the user should not rely.
750
+ /// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
751
+
752
+ /// \fn Tensor detach() const;
753
+ ///
754
+ /// Returns a new Tensor, detached from the current graph.
755
+ /// The result will never require gradient.
756
+
757
+ /// \fn Tensor & detach_() const;
758
+ ///
759
+ /// Detaches the Tensor from the graph that created it, making it a leaf.
760
+ /// Views cannot be detached in-place.
761
+
762
+ /// \fn void retain_grad() const;
763
+ ///
764
+ /// Enables this Tensor to have their :attr:`grad` populated during
765
+ /// :func:`backward`. This is a no-op for leaf tensors.
766
+
767
+ /// \fn bool retains_grad() const;
768
+ ///
769
+ /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
770
+ /// populated during :func:`backward`, ``false`` otherwise.
771
+
772
+ const TensorBase& set_requires_grad(bool requires_grad) const {
773
+ impl_->set_requires_grad(requires_grad);
774
+ return *this;
775
+ }
776
+ bool requires_grad() const {
777
+ return impl_->requires_grad();
778
+ }
779
+
780
+ // The Forward AD API functions below are low level and are not to be used by end
781
+ // users who should use the API provided in torch/csrc/autograd.h
782
+
783
+ /// This function returns the forward gradient for this Tensor at the given level.
784
+ const Tensor& _fw_grad(uint64_t level) const {
785
+ return impl_->_fw_grad(level, *this);
786
+ }
787
+
788
+ /// This function can be used to set the value of the forward grad.
789
+ /// Note that the given new_grad might not be used directly if it has different
790
+ /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
791
+ /// new_grad content will be copied into a new Tensor
792
+ void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
793
+ impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
794
+ }
795
+
796
+ /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
797
+ /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
798
+ /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
799
+ ///
800
+ /// One notable difference with the legacy `.data()` function is that changes to the
801
+ /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
802
+ /// will not update the original `Variable`, due to the fact that this function
803
+ /// shallow-copies the `Variable`'s underlying TensorImpl.
804
+ at::TensorBase tensor_data() const;
805
+
806
+ /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
807
+ /// in Python, which create a new `Variable` that shares the same storage and
808
+ /// tensor metadata with the original `Variable`, but with a completely new
809
+ /// autograd history.
810
+ ///
811
+ /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
812
+ /// storage / storage_offset) of a variable created from `var.variable_data()`, those
813
+ /// changes will not update the original variable `var`. In `.variable_data()`, we set
814
+ /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
815
+ /// in order to prevent users from changing metadata of `var.variable_data()`
816
+ /// and expecting the original variable `var` to also be updated.
817
+ at::TensorBase variable_data() const;
818
+
819
+ // Gradient Node and Edges
820
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
821
+
822
+ /// Gets the gradient function of the `Variable`. If this is a leaf variable,
823
+ /// the pointer returned will be null.
824
+ ///
825
+ /// For View Variables:
826
+ /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
827
+ /// re-create the grad_fn to express the up-to-date view relationship between
828
+ /// this and the base Variable.
829
+ const std::shared_ptr<torch::autograd::Node>& grad_fn() const;
830
+
831
+ // Hooks
832
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
833
+
834
+ template <typename T>
835
+ using hook_return_void_t = std::enable_if_t<std::is_void_v<typename std::invoke_result_t<T&, TensorBase>>, unsigned>;
836
+ template <typename T>
837
+ using hook_return_var_t = std::enable_if_t<std::is_same_v<typename std::invoke_result_t<T&, TensorBase>, TensorBase>, unsigned>;
838
+
839
+ /// Registers a backward hook.
840
+ ///
841
+ /// The hook will be called every time a gradient with respect to the Tensor is computed.
842
+ /// The hook should have one of the following signatures:
843
+ /// ```
844
+ /// hook(TensorBase grad) -> TensorBase
845
+ /// ```
846
+ /// ```
847
+ /// hook(TensorBase grad) -> void
848
+ /// ```
849
+ /// The hook should not modify its argument, but it can optionally return a new gradient
850
+ /// which will be used in place of `grad`.
851
+ ///
852
+ /// This function returns the index of the hook in the list which can be used to remove hook.
853
+ ///
854
+ /// Example:
855
+ /// @code
856
+ /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
857
+ /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
858
+ /// v.backward(torch::tensor({1., 2., 3.}));
859
+ /// // This prints:
860
+ /// // ```
861
+ /// // 2
862
+ /// // 4
863
+ /// // 6
864
+ /// // [ CPUFloatType{3} ]
865
+ /// // ```
866
+ /// std::cout << v.grad() << std::endl;
867
+ /// v.remove_hook(h); // removes the hook
868
+ /// @endcode
869
+ template <typename T>
870
+ hook_return_void_t<T> register_hook(T&& hook) const;
871
+ template <typename T>
872
+ hook_return_var_t<T> register_hook(T&& hook) const;
873
+
874
+ protected:
875
+ unsigned _register_hook(std::function<TensorBase(const TensorBase&)> hook) const;
876
+
877
+ public:
878
+
879
+ /// Remove hook at given position
880
+ void remove_hook(unsigned pos) const;
881
+
882
+ // Variable methods
883
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
884
+
885
+ bool is_leaf() const;
886
+
887
+ int64_t output_nr() const;
888
+
889
+ void set_data(const TensorBase & new_data) const;
890
+
891
+ TensorBase data() const;
892
+
893
+ int64_t _version() const;
894
+
895
+ void retain_grad() const;
896
+
897
+ bool retains_grad() const;
898
+
899
+ const TensorBase& requires_grad_(bool _requires_grad=true) const;
900
+
901
+ // View Variables
902
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
903
+
904
+ /// Returns true if this `Variable` is a view of another `Variable`.
905
+ bool is_view() const;
906
+
907
+ /// Returns the `Variable` that this `Variable` is a view of. If this
908
+ /// `Variable` is not a view, throw a `std::runtime_error`.
909
+ const TensorBase& _base() const;
910
+
911
+ // Miscellaneous
912
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
913
+
914
+ const std::string& name() const;
915
+
916
+ protected:
917
+ void enforce_invariants();
918
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
919
+
920
+ private:
921
+ TensorBase __dispatch_contiguous(c10::MemoryFormat) const;
922
+ };
923
+
924
+ inline DeviceIndex get_device(const TensorBase& self) {
925
+ return self.get_device();
926
+ }
927
+
928
+ template <typename T>
929
+ auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t<T> {
930
+ // Hooks with a void return type still need to produce a value for the
+ // std::function<TensorBase(const TensorBase&)> we store, so wrap them
+ // and return a default-constructed (undefined) TensorBase.
932
+ static_assert(std::is_same_v<decltype(hook(TensorBase())), void>,
933
+ "Expected hook to return void");
934
+ return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad) {
935
+ fn(grad);
936
+ return TensorBase();
937
+ });
938
+ }
939
+
940
+ template <typename T>
941
+ auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> {
942
+ return _register_hook(std::forward<T>(hook));
943
+ }
944
+
945
+ namespace detail {
946
+ // Helper creator for the Tensor class which doesn't require the user to
+ // pass in an intrusive_ptr; instead it converts the arguments passed in
+ // to the requested intrusive_ptr type.
949
+ template <typename T, typename... Args>
950
+ TensorBase make_tensor_base(Args&&... args) {
951
+ return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...));
952
+ }
953
+
954
+ } // namespace detail
955
+
956
+ inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
957
+ return legacyExtractDispatchKey(t.key_set());
958
+ }
959
+
960
+ } // namespace at
961
+
962
+ namespace c10 {
963
+ template <>
964
+ struct MaybeOwnedTraits<at::TensorBase> {
965
+ using owned_type = at::TensorBase;
966
+ using borrow_type = at::TensorBase;
967
+
968
+ static borrow_type createBorrow(const owned_type& from) {
969
+ // NOTE: this can be implemented without the special
970
+ // unsafe_borrow_t Tensor constructor as
971
+ //
972
+ // return borrow_type(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl()));
973
+ //
974
+ // but that hurts inlining due to the nullptr check in the
975
+ // Tensor(c10::intrusive_ptr<...>) constructor. We already know
976
+ // that from.impl_ isn't null because from is a valid Tensor, so
977
+ // we needn't do the check again. (using __builtin_assume can
978
+ // avoid this, but wouldn't be portable to MSVC.)
979
+ return borrow_type(borrow_type::unsafe_borrow_t{}, from);
980
+ }
981
+
982
+ static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
983
+ lhs.unsafeReleaseTensorImpl();
984
+ // See above note: this can be implemented with public API
985
+ // similarly to createBorrow(), but that would hurt inlining.
986
+ lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
987
+ }
988
+
989
+ static void destroyBorrow(borrow_type& toDestroy) {
990
+ toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
991
+ }
992
+
993
+ static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
994
+ return borrow;
995
+ }
996
+
997
+ static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
998
+ return &borrow;
999
+ }
1000
+
1001
+ static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
1002
+ return true;
1003
+ }
1004
+ };
1005
+
1006
+ template <>
1007
+ struct ExclusivelyOwnedTraits<at::TensorBase> : public c10::ExclusivelyOwnedTensorTraits<at::TensorBase> {};
1008
+ } // namespace c10
1009
+
1010
+ namespace at {
1011
+
1012
+ inline c10::MaybeOwned<TensorBase> borrow_from_optional_tensor(
1013
+ const std::optional<TensorBase>& opt) {
1014
+ return opt.has_value()
1015
+ ? c10::MaybeOwned<TensorBase>::borrowed(*opt)
1016
+ : c10::MaybeOwned<TensorBase>::owned(std::in_place);
1017
+ }
1018
+
1019
+ inline c10::MaybeOwned<TensorBase> TensorBase::expect_contiguous(MemoryFormat memory_format) const & {
1020
+ if (is_contiguous(memory_format)) {
1021
+ return c10::MaybeOwned<TensorBase>::borrowed(*this);
1022
+ } else {
1023
+ return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format));
1024
+ }
1025
+ }
1026
+
1027
+ namespace symint {
1028
+
1029
+ template <typename T>
1030
+ using enable_if_symint = std::enable_if_t<std::is_same_v<T, c10::SymInt>>;
1031
+ template <typename T>
1032
+ using enable_if_int = std::enable_if_t<std::is_same_v<T, int64_t>>;
1033
+
1034
+ template <typename T, typename = enable_if_symint<T>>
1035
+ c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); }
1036
+ template <typename T, typename = enable_if_int<T>>
1037
+ IntArrayRef sizes(const TensorBase& t) { return t.sizes(); }
1038
+
1039
+ template <typename T, typename = enable_if_symint<T>>
1040
+ c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); }
1041
+ template <typename T, typename = enable_if_int<T>>
1042
+ int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); }
1043
+
1044
+ template <typename T, typename = enable_if_symint<T>>
1045
+ c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); }
1046
+ template <typename T, typename = enable_if_int<T>>
1047
+ IntArrayRef strides(const TensorBase& t) { return t.strides(); }
1048
+
1049
+ template <typename T, typename = enable_if_symint<T>>
1050
+ c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); }
1051
+ template <typename T, typename = enable_if_int<T>>
1052
+ int64_t numel(const TensorBase& t) { return t.numel(); }
1053
+
1054
+ } // namespace symint
1055
+
1056
+ } // namespace at
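The `MaybeOwned` plumbing above is easiest to see in use. A minimal sketch, assuming only the declarations in this header (the function names `first_stride` and `defined_or_empty` are illustrative, not part of the upload):

```
#include <ATen/core/TensorBase.h>
#include <optional>

// expect_contiguous borrows when `t` is already contiguous and only
// materializes an owned contiguous copy otherwise (see definition above).
int64_t first_stride(const at::TensorBase& t) {
  c10::MaybeOwned<at::TensorBase> c = t.expect_contiguous();
  return c->strides()[0];  // assumes t has at least one dimension
}

// borrow_from_optional_tensor avoids a refcount bump when a value is present.
bool defined_or_empty(const std::optional<at::TensorBase>& opt) {
  c10::MaybeOwned<at::TensorBase> mb = at::borrow_from_optional_tensor(opt);
  return mb->defined();
}
```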
phivenv/Lib/site-packages/torch/include/ATen/core/TensorBody.h ADDED
The diff for this file is too large to render. See raw diff
 
phivenv/Lib/site-packages/torch/include/ATen/core/TorchDispatchUtils.h ADDED
@@ -0,0 +1,17 @@
1
+ #pragma once
2
+
3
+ #include <ATen/core/dispatch/Dispatcher.h>
4
+ #include <c10/core/impl/TorchDispatchModeTLS.h>
5
+ #include <c10/util/ArrayRef.h>
6
+ #include <torch/library.h>
7
+ #include <optional>
8
+
9
+ namespace at::impl {
10
+
11
+ TORCH_API bool tensor_has_dispatch(const at::Tensor& t);
12
+ TORCH_API bool tensorlist_has_dispatch(at::ITensorListRef li);
13
+ TORCH_API bool tensorlist_has_dispatch(
14
+ const c10::List<std::optional<at::Tensor>>& li);
15
+ using c10::impl::dispatch_mode_enabled;
16
+
17
+ } // namespace at::impl
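As a hedged sketch of how these helpers are typically consumed (the guard function is invented for illustration):

```
#include <ATen/core/TorchDispatchUtils.h>

// Take a fast path only when no dispatch mode is active and neither
// tensor carries a Python-dispatch key.
bool can_take_fast_path(const at::Tensor& a, const at::Tensor& b) {
  return !at::impl::dispatch_mode_enabled() &&
         !at::impl::tensor_has_dispatch(a) &&
         !at::impl::tensor_has_dispatch(b);
}
```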
phivenv/Lib/site-packages/torch/include/ATen/core/TransformationHelper.h ADDED
@@ -0,0 +1,175 @@
1
+ #include <ATen/NumericUtils.h>
2
+ #include <c10/macros/Macros.h>
3
+ #include <c10/util/Half.h>
4
+ #include <c10/util/BFloat16.h>
5
+ #include <c10/util/MathConstants.h>
6
+ #include <cmath>
7
+ #include <cstdint>
8
+ #include <cassert>
9
+ #include <limits>
10
+ #include <type_traits>
11
+
12
+ namespace at {
13
+
14
+ // Using DistAccumType in accumulate types for distributions.
15
+ // Note: Ideally we'd be using ATen/AccumulateType.h but looks
16
+ // like the there is some inconsistency in how accumulate types
17
+ // are mapped currently, e.g. for the cpu side, float is mapped
18
+ // to double.
19
+ template <typename T>
20
+ struct DistAccumType { };
21
+
22
+ #if defined(__CUDACC__) || defined(__HIPCC__)
23
+ template <> struct DistAccumType<half> { using type = float; };
24
+ #endif
25
+ template <> struct DistAccumType<BFloat16> { using type = float; };
26
+ template <> struct DistAccumType<Half> { using type = float; };
27
+ template <> struct DistAccumType<float> { using type = float; };
28
+ template <> struct DistAccumType<double> { using type = double; };
29
+
30
+ template <typename T>
31
+ using dist_acctype = typename DistAccumType<T>::type;
32
+
33
+ namespace transformation {
34
+
35
+ /**
36
+ * A transformation function for `torch.Tensor.random_()`, when both `from` and `to` are specified.
37
+ * `range` is `to - from`
38
+ * `base` is `from`
39
+ */
40
+ template <typename T, typename V>
41
+ C10_HOST_DEVICE inline T uniform_int_from_to(V val, uint64_t range, int64_t base) {
42
+ return static_cast<T>(static_cast<int64_t>((val % range) + base));
43
+ }
44
+
45
+ /**
46
+ * A transformation function for `torch.Tensor.random_()`, when `from=min_value(int64_t)` and to=None
47
+ */
48
+ template <typename T, typename V>
49
+ C10_HOST_DEVICE inline T uniform_int_full_range(V val) {
50
+ return static_cast<T>(static_cast<int64_t>(val));
51
+ }
52
+
53
+ /**
54
+ * A transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`.
55
+ * In order to prevent compiler warnings reported in GitHub issue 46391, T can't be float or double
56
+ * in this overloaded version
57
+ */
58
+ template <typename T, typename V>
59
+ C10_HOST_DEVICE inline std::enable_if_t<!(std::is_floating_point_v<T>), T>uniform_int(V val) {
60
+ if constexpr (std::is_same_v<T, bool>) {
61
+ return static_cast<bool>(val & 1);
62
+ } else if constexpr (std::is_same_v<T, int64_t>) {
63
+ return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
64
+ } else if constexpr (std::is_same_v<T, at::Half> || std::is_same_v<T, at::BFloat16>) {
65
+ return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
66
+ } else if constexpr (std::is_integral_v<T>) {
67
+ return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
68
+ } else {
69
+ assert(false);
70
+ return 0;
71
+ }
72
+ }
73
+
74
+ /**
75
+ * An overloaded transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`,
76
+ * added to fix compiler warnings reported in GitHub issue 46391. T is either float or double in this version.
77
+ */
78
+ template<typename T, typename V>
79
+ C10_HOST_DEVICE inline std::enable_if_t<std::is_floating_point_v<T>, T>uniform_int(V val) {
80
+ return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
81
+ }
82
+
83
+ template <typename T, typename V>
84
+ C10_HOST_DEVICE inline dist_acctype<T> uniform_real(V val, T from, T to) {
85
+ constexpr auto MASK = static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);
86
+ constexpr auto DIVISOR = static_cast<dist_acctype<T>>(1) / (static_cast<uint64_t>(1) << std::numeric_limits<T>::digits);
87
+ dist_acctype<T> x = (val & MASK) * DIVISOR;
88
+ return (x * (to - from) + from);
89
+ }
90
+
91
+ /**
92
+ * Transforms normally distributed `val` with mean 0.0 and standard deviation 1.0 to
93
+ * normally distributed with `mean` and standard deviation `std`.
94
+ */
95
+ template <typename T>
96
+ C10_HOST_DEVICE inline T normal(T val, T mean, T std) {
97
+ return val * std + mean;
98
+ }
99
+
100
+ /**
101
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
102
+ * Cauchy distribution with location parameter `median` and scale parameter `sigma`.
103
+ */
104
+ template <typename T>
105
+ C10_HOST_DEVICE inline T cauchy(T val, T median, T sigma) {
106
+ // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
107
+ // __tanf overflows and returns `inf/-inf` when (val > 1 - eps) or (val < 0 + eps),
108
+ // thus we clip those values.
109
+ constexpr T eps = std::numeric_limits<T>::epsilon();
110
+ constexpr T one_minus_eps = 1 - eps;
111
+ constexpr T zero_plus_eps = 0 + eps;
112
+ val = (val > one_minus_eps ? one_minus_eps : val);
113
+ val = (val < zero_plus_eps ? zero_plus_eps : val);
114
+ return median + sigma * at::tan(c10::pi<T> * (val - static_cast<T>(0.5)));
115
+ }
116
+
117
+ template <>
118
+ C10_HOST_DEVICE inline double cauchy(double val, double median, double sigma) {
119
+ // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
120
+ return median + sigma * at::tan(c10::pi<double> * (val - static_cast<double>(0.5)));
121
+ }
122
+
123
+ /**
124
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
125
+ * exponentially distributed with `lambda` parameter of the distribution.
126
+ */
127
+ template <typename T>
128
+ C10_HOST_DEVICE inline T exponential(T val, T lambda) {
129
+ // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
130
+ // Different implementations for CUDA and CPU to preserve original logic
131
+ // TODO: must be investigated and unified!!!
132
+ // https://github.com/pytorch/pytorch/issues/38662
133
+ #if defined(__CUDACC__) || defined(__HIPCC__)
134
+ // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
135
+ // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
136
+ // we need log to be not 0, and not underflow when converted to half
137
+ // fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args
138
+ auto log = val >= static_cast<T>(1.) - std::numeric_limits<T>::epsilon() / 2
139
+ ? -std::numeric_limits<T>::epsilon() / 2
140
+ : at::log(val);
141
+ return static_cast<T>(-1.0) / lambda * log;
142
+ #else
143
+ return static_cast<T>(-1.0) / lambda * at::log1p(-val);
144
+ #endif
145
+ }
146
+
147
+ /**
148
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
149
+ * geometrically distributed with success probability `p`.
150
+ */
151
+ template <typename T>
152
+ C10_HOST_DEVICE inline T geometric(T val, T p) {
153
+ // https://en.wikipedia.org/wiki/Geometric_distribution#Related_distributions
154
+ return static_cast<T>(::ceil(at::log(val) / at::log1p(-p)));
155
+ }
156
+
157
+ /**
158
+ * Transforms normally distributed `val` to log-normally distributed.
159
+ */
160
+ template <typename T>
161
+ C10_HOST_DEVICE inline T log_normal(T val) {
162
+ // https://en.wikipedia.org/wiki/Log-normal_distribution#Mode,_median,_quantiles
163
+ return at::exp(val);
164
+ }
165
+
166
+ /**
167
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
168
+ * bernoulli distributed with success probability `p`.
169
+ */
170
+ template <typename T>
171
+ C10_HOST_DEVICE inline T bernoulli(T val, T p) {
172
+ return val < p;
173
+ }
174
+
175
+ }} // namespace at::transformation
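A small sketch of how these transformations compose with a raw engine draw; `at::mt19937` is the CPU engine from MT19937RNGEngine.h in this same upload, and the two sampling functions are invented for illustration:

```
#include <ATen/core/MT19937RNGEngine.h>
#include <ATen/core/TransformationHelper.h>

float sample_uniform(at::mt19937& gen, float lo, float hi) {
  // gen() yields a raw uint32_t; uniform_real masks it down to the float
  // mantissa and rescales it into [lo, hi).
  return at::transformation::uniform_real<float>(gen(), lo, hi);
}

float sample_exponential(at::mt19937& gen, float lambda) {
  float u = at::transformation::uniform_real<float>(gen(), 0.0f, 1.0f);
  return at::transformation::exponential<float>(u, lambda);
}
```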
phivenv/Lib/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h ADDED
@@ -0,0 +1 @@
1
+ #include <c10/core/UndefinedTensorImpl.h>
phivenv/Lib/site-packages/torch/include/ATen/core/UnsafeFromTH.h ADDED
@@ -0,0 +1,21 @@
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+
4
+ namespace at {
5
+
6
+ inline Tensor unsafeTensorFromTH(void * th_pointer, bool retain) {
7
+ auto tensor_impl = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(static_cast<TensorImpl*>(th_pointer));
8
+ if (retain && tensor_impl.get() != UndefinedTensorImpl::singleton()) {
9
+ c10::raw::intrusive_ptr::incref(tensor_impl.get());
10
+ }
11
+ return Tensor(std::move(tensor_impl));
12
+ }
13
+
14
+ inline Storage unsafeStorageFromTH(void * th_pointer, bool retain) {
15
+ if (retain && th_pointer) {
16
+ c10::raw::intrusive_ptr::incref(static_cast<StorageImpl*>(th_pointer));
17
+ }
18
+ return Storage(c10::intrusive_ptr<StorageImpl>::reclaim(static_cast<StorageImpl*>(th_pointer)));
19
+ }
20
+
21
+ }
phivenv/Lib/site-packages/torch/include/ATen/core/VariableHooksInterface.h ADDED
@@ -0,0 +1,83 @@
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <c10/macros/Export.h>
5
+
6
+ // A little explanation about why this file exists at all. We have
7
+ // a few methods on Tensor class which require access to reified access to
8
+ // AutogradMeta. In open source, this isn't a big deal: we just access
9
+ // torch/csrc/autograd/variable.h from aten/src/ATen/core/Tensor.cpp and
10
+ // we can put the definitions inline. This is because everything gets balled
11
+ // into a single dynamic library in the end.
12
+ //
13
+ // However, inside our Facebook internal version of our build system, we
14
+ // have a split between aten and torch/csrc. So we cannot simply just
15
+ // cross this boundary. "Now wait," you might say, "Why don't we just
16
+ // merge the libraries inside Facebook". Well, the problem is that there
17
+ // are some downstream applications which are at binary size limit, and
18
+ // incorporating all of the extra code from libtorch would push them
19
+ // over (admarket/adreview/service:adreviewservice, see also
20
+ // https://github.com/pytorch/pytorch/pull/29299) So if you want to do that,
21
+ // we have to fix all of the services like this.
22
+ //
23
+ // I didn't want to block eliminating Tensor-Variable on this work, so I
24
+ // had to introduce another dynamic dispatch to get to the variable
25
+ // implementations (which live in torch/csrc/autograd/variable.cpp, FYI).
26
+ //
27
+ // I also considered using our existing dynamic dispatch mechanism, c10
28
+ // dispatcher, to do this. However, (1) some of the functions on Tensor
29
+ // have weird signatures that are not supported by autograd, and (2)
30
+ // see this bug https://github.com/pytorch/pytorch/issues/30102
31
+
32
+ namespace torch::autograd {
33
+
34
+ struct Node;
35
+
36
+ } // namespace torch::autograd
37
+
38
+ namespace at::impl {
39
+
40
+ struct TORCH_API VariableHooksInterface {
41
+ virtual ~VariableHooksInterface() = default;
42
+ virtual TensorBase tensor_data(const TensorBase&) const = 0;
43
+ virtual TensorBase variable_data(const TensorBase&) const = 0;
44
+ virtual const std::shared_ptr<torch::autograd::Node>& grad_fn(
45
+ const TensorBase&) const = 0;
46
+ virtual unsigned _register_hook(
47
+ const TensorBase&,
48
+ std::function<TensorBase(const TensorBase&)> hook) const = 0;
49
+ virtual void remove_hook(const TensorBase&, unsigned pos) const = 0;
50
+ virtual bool is_view(const TensorBase&) const = 0;
51
+ virtual const TensorBase& base(const TensorBase&) const = 0;
52
+ virtual const std::string& name(const TensorBase&) const = 0;
53
+ virtual bool is_leaf(const TensorBase&) const = 0;
54
+ virtual int64_t output_nr(const TensorBase&) const = 0;
55
+ virtual void set_data(const TensorBase&, const TensorBase&) const = 0;
56
+ virtual TensorBase data(const TensorBase&) const = 0;
57
+ virtual int64_t _version(const TensorBase&) const = 0;
58
+ virtual void retain_grad(const TensorBase&) const = 0;
59
+ virtual bool retains_grad(const TensorBase&) const = 0;
60
+ virtual void _backward(
61
+ const Tensor&,
62
+ TensorList,
63
+ const std::optional<Tensor>&,
64
+ std::optional<bool>,
65
+ bool) const = 0;
66
+ virtual void requires_grad_(const TensorBase&, bool) const = 0;
67
+ virtual void basic_autograd_not_implemented_fallback(
68
+ const c10::OperatorHandle& op,
69
+ c10::DispatchKeySet dispatch_keys,
70
+ torch::jit::Stack* stack) const = 0;
71
+ };
72
+
73
+ TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);
74
+ TORCH_API VariableHooksInterface* GetVariableHooks();
75
+ TORCH_API bool HasVariableHooks();
76
+
77
+ struct TORCH_API VariableHooksRegisterer {
78
+ explicit VariableHooksRegisterer(VariableHooksInterface* hooks) {
79
+ SetVariableHooks(hooks);
80
+ }
81
+ };
82
+
83
+ } // namespace at::impl
phivenv/Lib/site-packages/torch/include/ATen/core/Variadic.h ADDED
@@ -0,0 +1,92 @@
1
+ #pragma once
2
+
3
+ #include <utility>
4
+
5
+ #include <c10/util/ArrayRef.h>
6
+ #include <ATen/core/List.h>
7
+
8
+ namespace at {
9
+
10
+ // This class allows you to write variadic functions which
11
+ // call a (possibly overloaded) function on each argument,
12
+ // in order. This is most commonly used in autogenerated code,
13
+ // where it is convenient to have a function that can uniformly
14
+ // take arguments of different types. If your arguments
15
+ // are homogeneous, consider using a std::initializer_list instead.
16
+ //
17
+ // For examples of this in use, see torch/csrc/utils/variadic.h
18
+ template <typename F>
19
+ struct IterArgs {
20
+ template <typename... Args>
21
+ inline F& apply() {
22
+ return self();
23
+ }
24
+
25
+ // NB: Use perfect forwarding here, otherwise we'll make value
26
+ // copies of all arguments!
27
+ template <typename T, typename... Args>
28
+ inline F& apply(T&& arg, Args&&... args) {
29
+ self()(std::forward<T>(arg));
30
+ if (self().short_circuit()) {
31
+ return self();
32
+ } else {
33
+ return apply(std::forward<Args>(args)...);
34
+ }
35
+ }
36
+
37
+ // Here are some handy overloads which provide sensible
38
+ // defaults for container-like structures that one might
39
+ // be interested in recursing into. You can enable them
40
+ // by adding:
41
+ //
42
+ // using IterArgs<YourStructName>::operator()
43
+ //
44
+ // to your struct. These are not enabled by default because
45
+ // you may be able to process these structures more efficiently
46
+ // than handling them one-by-one.
47
+
48
+ template <typename T>
49
+ void operator()(c10::IListRef<T> args) {
50
+ for (const auto& arg : args) {
51
+ self()(arg);
52
+ if (self().short_circuit())
53
+ return;
54
+ }
55
+ }
56
+
57
+ template <typename T>
58
+ void operator()(at::ArrayRef<T> args) {
59
+ for (const auto& arg : args) {
60
+ self()(arg);
61
+ if (self().short_circuit())
62
+ return;
63
+ }
64
+ }
65
+
66
+ template <typename T>
67
+ void operator()(const torch::List<T>& args) {
68
+ for (const auto& arg : args) {
69
+ self()(arg);
70
+ if (self().short_circuit())
71
+ return;
72
+ }
73
+ }
74
+
75
+ // NB: we need to specify std::vector manually as C++ won't
76
+ // do an implicit conversion to make a template deduction go through.
77
+ template <typename T>
78
+ void operator()(const std::vector<T>& args) {
79
+ self()(at::ArrayRef<T>{args});
80
+ }
81
+
82
+ constexpr bool short_circuit() const {
83
+ return false;
84
+ }
85
+
86
+ private:
87
+ inline F& self() {
88
+ return *static_cast<F*>(this);
89
+ }
90
+ };
91
+
92
+ } // namespace at
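Following the pointer above to torch/csrc/utils/variadic.h, a representative use is counting the Tensor arguments in a variadic call; this is a sketch modeled on that file:

```
#include <ATen/core/Tensor.h>
#include <ATen/core/Variadic.h>
#include <utility>

// Visits each argument in order; only Tensor arguments bump the count.
struct CountTensors : at::IterArgs<CountTensors> {
  size_t out = 0;
  void operator()(const at::Tensor&) { ++out; }
  template <typename T>
  void operator()(const T&) {}  // every other argument type is ignored
};

template <typename... Args>
size_t count_tensors(Args&&... args) {
  return CountTensors().apply(std::forward<Args>(args)...).out;
}
```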
phivenv/Lib/site-packages/torch/include/ATen/core/Vitals.h ADDED
@@ -0,0 +1,94 @@
1
+ #pragma once
2
+ #include <ostream>
3
+ #include <sstream>
4
+ #include <unordered_map>
5
+
6
+ #include <c10/core/impl/LocalDispatchKeySet.h>
7
+
8
+ namespace at::vitals {
9
+
10
+ TORCH_API bool torchVitalEnabled();
11
+
12
+ struct TORCH_API TorchVitalAttr {
13
+ // always initialized to empty
14
+ std::string value;
15
+ template <typename T>
16
+ TorchVitalAttr& operator<<(const T& t) {
17
+ if (torchVitalEnabled()) {
18
+ std::stringstream ss;
19
+ ss << t;
20
+ value += ss.str();
21
+ }
22
+ return *this;
23
+ }
24
+
25
+ template <typename T>
26
+ void write(const T& t, bool force) {
27
+ if (force || torchVitalEnabled()) {
28
+ std::stringstream ss;
29
+ ss << t;
30
+ value = ss.str();
31
+ }
32
+ }
33
+ };
34
+
35
+ struct TORCH_API TorchVital {
36
+ std::string name;
37
+ std::unordered_map<std::string, TorchVitalAttr> attrs;
38
+
39
+ explicit TorchVital(std::string n) : name(std::move(n)) {}
40
+ TorchVital(const TorchVital&) = default;
41
+ TorchVital(TorchVital&&) = default;
42
+ TorchVital& operator=(const TorchVital&) = default;
43
+ TorchVital& operator=(TorchVital&&) = default;
44
+ TorchVital() = delete;
45
+
46
+ TorchVitalAttr& create(const std::string& attr);
47
+ TorchVitalAttr& create(const std::string& attr, bool force);
48
+ friend std::ostream& operator<<(std::ostream& os, const TorchVital& dt);
49
+
50
+ ~TorchVital();
51
+ };
52
+
53
+ std::ostream& operator<<(std::ostream& os, TorchVital const& tv);
54
+
55
+ // A way to access vitals by string names instead of by global reference.
56
+ // This enables access to vitals from the PythonAPI.
57
+ class TORCH_API APIVitals {
58
+ public:
59
+ bool vitals_enabled;
60
+
61
+ // Set any vital sign that was added to the map.
62
+ bool setVital(
63
+ const std::string& vital_name,
64
+ const std::string& attr_name,
65
+ const std::string& value,
66
+ bool force = false);
67
+ std::string readVitals();
68
+
69
+ APIVitals();
70
+
71
+ // Ensure this stays a singleton
72
+ APIVitals(APIVitals const& other) = delete;
73
+ APIVitals(APIVitals&& other) = delete;
74
+ APIVitals& operator=(const APIVitals&) = delete;
75
+ APIVitals& operator=(APIVitals&&) = delete;
76
+ ~APIVitals() = default;
77
+
78
+ private:
79
+ std::unordered_map<std::string, TorchVital> name_map_;
80
+ };
81
+
82
+ extern TORCH_API APIVitals VitalsAPI;
83
+
84
+ } // namespace at::vitals
85
+
86
+ #define TORCH_VITAL_DECLARE(name) \
87
+ TORCH_API at::vitals::TorchVital TorchVital_##name;
88
+
89
+ #define TORCH_VITAL_DEFINE(name) \
90
+ TORCH_API at::vitals::TorchVital TorchVital_##name(#name);
91
+
92
+ #define TORCH_VITAL_BASE(name) TorchVital_##name
93
+
94
+ #define TORCH_VITAL(name, attr) TORCH_VITAL_BASE(name).create(#attr)
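A hedged sketch of the macro flow; the vital name `example` and its attribute are invented for illustration:

```
#include <ATen/core/Vitals.h>

// Defines the global at::vitals::TorchVital named TorchVital_example.
TORCH_VITAL_DEFINE(example)

void record_startup() {
  // Streams into the "status" attribute; TorchVitalAttr::operator<< is a
  // no-op unless torchVitalEnabled() reports that vitals are switched on.
  TORCH_VITAL(example, status) << "initialized";
}
```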
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h ADDED
@@ -0,0 +1,213 @@
1
+ #pragma once
2
+
3
+ #include <ATen/core/boxing/OperatorKernel.h>
4
+ #include <c10/core/DispatchKeySet.h>
5
+ #include <c10/util/intrusive_ptr.h>
6
+
7
+ namespace c10 {
8
+
9
+ struct IValue;
10
+ using Stack = std::vector<IValue>;
11
+
12
+ class OperatorHandle;
13
+ class KernelFunction;
14
+
15
+ // This kernel implements the behavior of falling through to the next available
16
+ // registered dispatch key. The implementation of this function is FAST; it is
17
+ // no overhead to fallthrough to the next key. See cpp file for some more
18
+ // implementation notes; notably, this does NOT actually go through the
19
+ // boxing/unboxing codepath.
20
+ TORCH_API void fallthrough_kernel(
21
+ OperatorKernel*,
22
+ const OperatorHandle&,
23
+ DispatchKeySet,
24
+ Stack*);
25
+
26
+ // Note [Ambiguity in AutogradOther kernel]
27
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
28
+ // This error-reporting kernel is registered to the AutogradOther entry in the
29
+ // dispatch table when there is both a CompositeImplicitAutograd kernel and a
30
+ // backend kernel for ANY backend that maps to AutogradOther. To see why
31
+ // this is necessary in the AutogradOther case, it's helpful to first see
32
+ // why everything works out fine for a backend that has a reserved Autograd
33
+ // entry (see rule 2.2 in [Note] DispatchTable computation):
34
+ //
35
+ // CPU AutogradCPU
36
+ // reg? registers with...
37
+ // -------------------------------------------------
38
+ // y Autograd registration takes precedence
39
+ // over CompositeImplicitAutograd.
40
+ // This is good, because the CPU specific backend
41
+ // implementation is more specialized and typically better;
42
+ // if we used the composite, we would bypass it.
43
+ // (NB: the Autograd key is guaranteed to exist because
44
+ // the autograd codegen requires it!)
45
+ //
46
+ // n CompositeImplicitAutograd takes precedence.
47
+ // This is also good, because the Autograd
48
+ // registration (if it exists) would try to redispatch
49
+ // to the (non-existent) CPU implementation; by
50
+ // using the composite, we ensure the operator
51
+ // actually works.
52
+ //
53
+ // As you can see, when we have a specific Autograd key (AutogradCPU), we can
54
+ // decide whether or not to use the CompositeImplicitAutograd kernel or the
55
+ // Autograd kernel based on whether or not the backend kernel exists.
56
+ //
57
+ // However, for AutogradOther (which is the catchall autograd kernel for
58
+ // everything that doesn't have a specific Autograd key), we can't do this
59
+ // trick because there isn't any unique backend to peek at to disambiguate;
60
+ // if there are some backends that have implementations they prefer Autograd,
61
+ // but unimplemented backends would prefer CompositeImplicitAutograd. Rather
62
+ // than arbitrarily pick one or the other, we just register a kernel that raises
63
+ // an error and let the user decide how to proceed.
64
+ TORCH_API void ambiguous_autogradother_kernel(
65
+ OperatorKernel*,
66
+ const OperatorHandle&,
67
+ DispatchKeySet,
68
+ Stack*);
69
+
70
+ // Note [named_not_supported_kernel]
71
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
72
+ // This kernel implements reporting an error message saying that named tensor is
73
+ // not supported. This kernel doesn't rely on the Stack, and so it is special
74
+ // cased in the dispatcher to be triggered before we attempt boxing (so we can
75
+ // give a good error message in cases when boxing is not supported). When
76
+ // boxing is universally supported this can be removed.
77
+ [[noreturn]] TORCH_API void named_not_supported_kernel(
78
+ OperatorKernel*,
79
+ const OperatorHandle&,
80
+ DispatchKeySet,
81
+ Stack*);
82
+
83
+ /**
84
+ * BoxedKernel is similar to a std::function storing a boxed kernel.
85
+ */
86
+ class TORCH_API BoxedKernel final {
87
+ public:
88
+ // This is how boxed kernels are actually stored
89
+ //
90
+ // Note [Plumbing Keys Through The Dispatcher]
91
+ // Benchmarks have shown that it is expensive for the dispatcher to read from
92
+ // thread-local storage (TLS) upon every dispatch call into order to compute
93
+ // which kernel to dispatch to.
94
+ //
95
+ // To mitigate this, we've updated the calling convention inside the
96
+ // dispatcher to expect every kernel that it stores to have a first argument
97
+ // of type DispatchKeySet.
98
+ //
99
+ // What are the invariants of the DispatchKeySet when it gets passed to a
100
+ // kernel?
101
+ // - All keys to the left of the current dispatch key have been masked out.
102
+ // (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the
103
+ // highest bit to be DispatchKey::Tracer)
104
+ // - All other keys that dispatcher normally would have computed through TLS +
105
+ // global state + op arguments
106
+ // are still in the set.
107
+ //
108
+ // Kernels can then opt into using this keyset to save the dispatcher from
109
+ // doing repeated work during redispatches: recalculating the highest-priority
110
+ // dispatch key, which involves reading from TLS. Instead, the kernels that
111
+ // opt in will calculate an updated DispatchKeySet directly from the old one,
112
+ // and pass the updated set directly into the dispatcher upon redispatching.
113
+ //
114
+ // This is an opt-in mechanism: Kernels can automatically opt in by setting
115
+ // the first argument in their signature to be of type DispatchKeySet. See the
116
+ // kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for
117
+ // examples.
118
+ //
119
+ // The mechanism for optionally passing that DispatchKeySet into the kernel
120
+ // lives in make_boxed_from_unboxed_functor.h. See Note [Plumbing Keys Through
121
+ // The Dispatcher 2] for details.
122
+ using InternalBoxedKernelFunction =
123
+ void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
124
+ // This is the public API for how boxed kernels are defined
125
+ using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
126
+ using BoxedKernelFunction_withDispatchKeys =
127
+ void(const OperatorHandle&, DispatchKeySet, Stack*);
128
+
129
+ BoxedKernel();
130
+
131
+ // Fast path for dispatch to allow not touching the boxed kernel in
132
+ // the common case where unboxed is available.
133
+ bool isValid() const;
134
+ bool isFallthrough() const;
135
+
136
+ /**
137
+ * Call the function with boxed arguments.
138
+ */
139
+ void callBoxed(
140
+ const OperatorHandle& opHandle,
141
+ DispatchKeySet dispatchKeySet,
142
+ Stack* stack) const;
143
+
144
+ /**
145
+ * Create a KernelFunction from a boxed function.
146
+ *
147
+ * Example:
148
+ *
149
+ * > void boxed_func(OperatorKernel*, Stack* stack) {...}
150
+ * > BoxedFunction func = BoxedKernel::makeFromFunction<&boxed_func>();
151
+ */
152
+ template <BoxedKernelFunction* func>
153
+ static BoxedKernel makeFromFunction();
154
+
155
+ /**
156
+ * TODO: This will only be useful if we write a backend fallback that plumbs
157
+ * dispatch keys (currently there are none) See Note [Plumbing Keys Through
158
+ * The Dispatcher] for details.
159
+ */
160
+ template <BoxedKernelFunction_withDispatchKeys* func>
161
+ static BoxedKernel makeFromFunction();
162
+
163
+ /**
164
+ * Create a BoxedKernel from a boxed functor.
165
+ *
166
+ * Example:
167
+ *
168
+ * > class MyFunctor final : public c10::OperatorKernel {
169
+ * > public:
170
+ * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
171
+ * > };
172
+ * > BoxedKernel func =
173
+ * BoxedKernel::makeFromFunctor(std::make_unique<MyFunctor>());
174
+ */
175
+ template <class KernelFunctor>
176
+ static BoxedKernel makeFromFunctor(
177
+ std::unique_ptr<KernelFunctor> kernelFunctor);
178
+
179
+ static BoxedKernel makeFallthrough();
180
+ static BoxedKernel makeAmbiguousAutogradOther();
181
+ static BoxedKernel makeNamedNotSupported();
182
+
183
+ private:
184
+ friend class KernelFunction;
185
+
186
+ template <BoxedKernelFunction* func>
187
+ static void make_boxed_function(
188
+ OperatorKernel*,
189
+ const OperatorHandle& opHandle,
190
+ DispatchKeySet,
191
+ Stack* stack);
192
+
193
+ template <BoxedKernelFunction_withDispatchKeys* func>
194
+ static void make_boxed_function(
195
+ OperatorKernel*,
196
+ const OperatorHandle& opHandle,
197
+ DispatchKeySet,
198
+ Stack* stack);
199
+
200
+ explicit BoxedKernel(
201
+ std::unique_ptr<OperatorKernel> functor,
202
+ InternalBoxedKernelFunction* boxed_kernel_func);
203
+
204
+ OperatorKernel* getFunctor() const;
205
+ InternalBoxedKernelFunction* getFnPtr() const;
206
+
207
+ c10::intrusive_ptr<OperatorKernel> functor_;
208
+ InternalBoxedKernelFunction* boxed_kernel_func_;
209
+ };
210
+
211
+ } // namespace c10
212
+
213
+ #include <ATen/core/boxing/BoxedKernel_impl.h>
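To make the factory shapes above concrete, a minimal sketch of a boxed kernel wrapped via makeFromFunction (the kernel name and body are placeholders):

```
#include <ATen/core/boxing/BoxedKernel.h>

// Matches BoxedKernelFunction: arguments and results travel as IValues on
// the stack rather than as typed C++ parameters.
void log_and_noop(const c10::OperatorHandle& op, c10::Stack* stack) {
  (void)op;     // e.g. inspect op.schema() here
  (void)stack;  // e.g. rearrange IValues here
}

c10::BoxedKernel make_logger_kernel() {
  return c10::BoxedKernel::makeFromFunction<&log_and_noop>();
}
```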
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h ADDED
@@ -0,0 +1,106 @@
1
+ #pragma once
2
+
3
+ namespace c10 {
4
+
5
+ inline BoxedKernel::BoxedKernel() : functor_(), boxed_kernel_func_(nullptr) {}
6
+
7
+ inline BoxedKernel::BoxedKernel(
8
+ std::unique_ptr<OperatorKernel> functor,
9
+ InternalBoxedKernelFunction* boxed_kernel_func)
10
+ : functor_(std::move(functor)), boxed_kernel_func_(boxed_kernel_func) {}
11
+
12
+ template <BoxedKernel::BoxedKernelFunction* func>
13
+ inline void BoxedKernel::make_boxed_function(
14
+ OperatorKernel*,
15
+ const OperatorHandle& opHandle,
16
+ DispatchKeySet,
17
+ Stack* stack) {
18
+ // Note that we're dropping the DispatchKeySet argument.
19
+ // See Note [Plumbing Keys Through The Dispatcher 2] for details.
20
+ func(opHandle, stack);
21
+ }
22
+
23
+ template <BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
24
+ inline void BoxedKernel::make_boxed_function(
25
+ OperatorKernel*,
26
+ const OperatorHandle& opHandle,
27
+ DispatchKeySet ks,
28
+ Stack* stack) {
29
+ // See Note [Plumbing Keys Through The Dispatcher 2] for details.
30
+ func(opHandle, ks, stack);
31
+ }
32
+
33
+ inline bool BoxedKernel::isValid() const {
34
+ return boxed_kernel_func_ != nullptr;
35
+ }
36
+
37
+ inline bool BoxedKernel::isFallthrough() const {
38
+ return boxed_kernel_func_ == &fallthrough_kernel;
39
+ }
40
+
41
+ inline void BoxedKernel::callBoxed(
42
+ const OperatorHandle& opHandle,
43
+ DispatchKeySet dispatchKeySet,
44
+ Stack* stack) const {
45
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
46
+ boxed_kernel_func_ != nullptr,
47
+ "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel.");
48
+ (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack);
49
+ }
50
+
51
+ template <BoxedKernel::BoxedKernelFunction* func>
52
+ inline BoxedKernel BoxedKernel::makeFromFunction() {
53
+ return BoxedKernel(
54
+ nullptr, // no functor_ object
55
+ &make_boxed_function<func>);
56
+ }
57
+
58
+ template <BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
59
+ inline BoxedKernel BoxedKernel::makeFromFunction() {
60
+ return BoxedKernel(
61
+ nullptr, // no functor_ object
62
+ &make_boxed_function<func>);
63
+ }
64
+
65
+ inline BoxedKernel BoxedKernel::makeFallthrough() {
66
+ return BoxedKernel(
67
+ nullptr, // no functor_ object
68
+ &fallthrough_kernel);
69
+ }
70
+
71
+ inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() {
72
+ return BoxedKernel(
73
+ nullptr, // no functor_ object
74
+ &ambiguous_autogradother_kernel);
75
+ }
76
+
77
+ inline BoxedKernel BoxedKernel::makeNamedNotSupported() {
78
+ return BoxedKernel(
79
+ nullptr, // no functor_ object
80
+ &named_not_supported_kernel);
81
+ }
82
+
83
+ template <class KernelFunctor>
84
+ inline BoxedKernel BoxedKernel::makeFromFunctor(
85
+ std::unique_ptr<KernelFunctor> kernelFunctor) {
86
+ static_assert(
87
+ std::is_base_of_v<OperatorKernel, KernelFunctor>,
88
+ "Tried to call BoxedKernel::makeFromFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
89
+ return BoxedKernel(
90
+ std::move(kernelFunctor),
91
+ [](OperatorKernel* kernel,
92
+ const OperatorHandle& op,
93
+ DispatchKeySet ks,
94
+ Stack* stack) {
95
+ (*static_cast<KernelFunctor*>(kernel))(op, ks, stack);
96
+ });
97
+ }
98
+
99
+ inline OperatorKernel* BoxedKernel::getFunctor() const {
100
+ return functor_.get();
101
+ }
102
+ inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const {
103
+ return boxed_kernel_func_;
104
+ }
105
+
106
+ } // namespace c10
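And the functor path just defined, as a sketch; `NoopFunctor` is invented for illustration and must derive from c10::OperatorKernel, as the static_assert above enforces:

```
#include <ATen/core/boxing/BoxedKernel.h>
#include <memory>

// The call operator matches the trampoline lambda in makeFromFunctor above.
struct NoopFunctor final : c10::OperatorKernel {
  void operator()(const c10::OperatorHandle&, c10::DispatchKeySet, c10::Stack*) {}
};

c10::BoxedKernel make_noop_kernel() {
  return c10::BoxedKernel::makeFromFunctor(std::make_unique<NoopFunctor>());
}
```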
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction.h ADDED
@@ -0,0 +1,283 @@
1
+ #pragma once
2
+
3
+ #include <ATen/core/ATen_fwd.h>
4
+ #include <ATen/core/boxing/BoxedKernel.h>
5
+ #include <ATen/core/stack.h>
6
+ #include <c10/core/DispatchKeySet.h>
7
+ #include <c10/util/TypeList.h>
8
+ #include <c10/util/intrusive_ptr.h>
9
+ #include <type_traits>
10
+
11
+ namespace c10 {
12
+
13
+ using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack
14
+ // to the c10 namespace.
15
+
16
+ class OperatorHandle;
17
+ struct OperatorKernel;
18
+ class KernelFunction;
19
+
20
+ template <typename T>
21
+ using has_symint = std::disjunction<
22
+ std::is_same<c10::SymInt, T>,
23
+ std::is_same<c10::SymIntArrayRef, T>,
24
+ std::is_same<at::OptionalSymIntArrayRef, T>,
25
+ std::is_same<std::optional<c10::SymInt>, T>>;
26
+
27
+ template <typename T>
28
+ struct remove_symint {
29
+ using type = T;
30
+ };
31
+
32
+ template <>
33
+ struct remove_symint<c10::SymInt> {
34
+ using type = int64_t;
35
+ };
36
+
37
+ template <>
38
+ struct remove_symint<at::OptionalSymIntArrayRef> {
39
+ using type = OptionalIntArrayRef;
40
+ };
41
+
42
+ template <>
43
+ struct remove_symint<c10::SymIntArrayRef> {
44
+ using type = c10::IntArrayRef;
45
+ };
46
+
47
+ template <>
48
+ struct remove_symint<std::optional<c10::SymInt>> {
49
+ using type = std::optional<int64_t>;
50
+ };
51
+
52
+ template <bool symint, typename T>
53
+ struct maybe_keep_symint final {};
54
+
55
+ template <typename T>
56
+ struct maybe_keep_symint<true, T> {
57
+ using type = T;
58
+ };
59
+
60
+ template <typename T>
61
+ struct maybe_keep_symint<false, T> {
62
+ using type = typename remove_symint<T>::type;
63
+ };
64
+
65
+ template <typename T>
66
+ using fn_has_symint = typename guts::typelist::true_for_any_type<
67
+ has_symint,
68
+ typename guts::infer_function_traits<T>::type::parameter_types>;
69
+
70
+ template <typename T>
71
+ struct fn_remove_symint;
72
+
73
+ template <typename Ret, typename... Args>
74
+ struct fn_remove_symint<Ret(Args...)> {
75
+ using type = Ret(typename remove_symint<Args>::type...);
76
+ };
77
+
78
+ /**
79
+ * KernelFunction is similar to std::function but stores a kernel function.
80
+ * You can create a KernelFunction from a boxed or unboxed
81
+ * function/functor/lambda and call it in a boxed or unboxed way. If the way it
82
+ * was created doesn't match the way it was called, it will do boxing or
83
+ * unboxing as necessary.
84
+ */
85
+ class TORCH_API KernelFunction final {
86
+ public:
87
+ using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction;
88
+ using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction;
89
+ using BoxedKernelFunction_withDispatchKeys =
90
+ BoxedKernel::BoxedKernelFunction_withDispatchKeys;
91
+
92
+ KernelFunction();
93
+
94
+ // Fast path for dispatch to allow not touching the boxed kernel in
95
+ // the common case where unboxed is available.
96
+ bool isValidUnboxed() const;
97
+ bool isValidSymUnboxed() const;
98
+ bool isValid() const;
99
+ bool isFallthrough() const;
100
+
101
+ /**
102
+ * Call the function in a boxed way.
103
+ * If the kernel function was created with an unboxed function,
104
+ * this will call an unboxing wrapper which then calls into that
105
+ * unboxed function.
106
+ *
107
+ * Example:
108
+ *
109
+ * > void boxed_func(const OperatorHandle&, Stack* stack) {...}
+ * > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>();
+ * > func.callBoxed(opHandle, dispatchKeySet, stack);
112
+ *
113
+ * Or, with an unboxed implementation:
114
+ *
115
+ * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
116
+ * > [] (Tensor a, bool b) -> Tensor {...});
117
+ * > func.callBoxed(opHandle, dispatchKeySet, stack);
118
+ */
119
+ void callBoxed(
120
+ const OperatorHandle& opHandle,
121
+ DispatchKeySet dispatchKeySet,
122
+ Stack* stack) const;
123
+
124
+ /**
125
+ * Call the function in an unboxed way.
126
+ * If the kernel function was created with a boxed function,
127
+ * this will box all inputs and then call into that boxed function.
128
+ *
129
+ * Note that this doesn't work for all types yet.
130
+ *
131
+ * Example:
132
+ *
133
+ * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
134
+ * > [] (Tensor a, bool b) -> Tensor {...});
135
+ * > Tensor result = func.call<Tensor, Tensor, bool>(opHandle, dispatchKeySet, tensor1, true);
136
+ *
137
+ * Or, with a boxed implementation:
138
+ *
139
+ * > void boxed_func(const OperatorHandle&, Stack* stack) {...}
+ * > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>();
+ * > Tensor result = func.call<Tensor, Tensor, bool>(opHandle, dispatchKeySet, tensor1, true);
142
+ */
143
+ template <class Return, class... Args>
144
+ Return call(
145
+ const OperatorHandle& opHandle,
146
+ DispatchKeySet dispatchKeySet,
147
+ Args... args) const;
148
+
149
+ /**
150
+ * Create a KernelFunction from a BoxedKernel.
151
+ */
152
+ static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn);
153
+
154
+ /**
155
+ * Create a KernelFunction from a boxed function.
156
+ *
157
+ * Example:
158
+ *
159
+ * > void boxed_func(const OperatorHandle&, Stack* stack) {...}
160
+ * > KernelFunction func =
161
+ * KernelFunction::makeFromBoxedFunction<&boxed_func>();
162
+ */
163
+ template <BoxedKernelFunction* func>
164
+ static KernelFunction makeFromBoxedFunction();
165
+
166
+ /**
167
+ * TODO: This will only be useful if we write a backend fallback that plumbs
168
+ * dispatch keys (currently there are none) See Note [Plumbing Keys Through
169
+ * The Dispatcher] for details.
170
+ */
171
+ template <BoxedKernelFunction_withDispatchKeys* func>
172
+ static KernelFunction makeFromBoxedFunction();
173
+
174
+ /**
175
+ * Create a KernelFunction from an unboxed functor.
176
+ *
177
+ * Example:
178
+ *
179
+ * > class MyFunctor final : public c10::OperatorKernel {
180
+ * > public:
181
+ * > Tensor operator()(Tensor a, Tensor b) {...}
182
+ * > };
183
+ * > KernelFunction func =
184
+ * KernelFunction::makeFromUnboxedFunctor<MyFunctor>(std::make_unique<MyFunctor>());
185
+ */
186
+ template <bool AllowLegacyTypes = false, class KernelFunctor>
187
+ static KernelFunction makeFromUnboxedFunctor(
188
+ std::unique_ptr<OperatorKernel> kernelFunctor);
189
+
190
+ /**
191
+ * Create a KernelFunction from a boxed functor.
192
+ *
193
+ * Example:
194
+ *
195
+ * > class MyFunctor final : public c10::OperatorKernel {
196
+ * > public:
197
+ * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
198
+ * > };
199
+ * > KernelFunction func =
200
+ * KernelFunction::makeFromBoxedFunctor(std::make_unique<MyFunctor>());
201
+ */
202
+ template <class KernelFunctor>
203
+ static KernelFunction makeFromBoxedFunctor(
204
+ std::unique_ptr<KernelFunctor> kernelFunctor);
205
+
206
+ /**
207
+ * Create a KernelFunction from an unboxed function.
208
+ * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction
209
+ * because knowing the function pointer as a template argument (i.e. at
210
+ * compile time) allows the compiler to inline the function into its
211
+ * unboxing wrapper and yields better performance when calling the function.
212
+ *
213
+ * Example:
214
+ *
215
+ * > Tensor unboxed_func(Tensor a, Tensor b) {...}
216
+ * > KernelFunction func =
217
+ * KernelFunction::makeFromUnboxedFunction<decltype(unboxed_func),
218
+ * &unboxed_func>();
219
+ */
220
+ template <class FuncPtr, bool AllowLegacyTypes = false>
221
+ static KernelFunction makeFromUnboxedFunction(FuncPtr);
222
+
223
+ /**
224
+ * Create a KernelFunction from an unboxed function.
225
+ * KernelFunction::makeFromUnboxedFunction is usually a better choice than
226
+ * this if you know the function pointer at compile time, see doc comment
227
+ * there for an explanation.
228
+ *
229
+ * Example:
230
+ *
231
+ * > Tensor unboxed_func(Tensor a, Tensor b) {...}
232
+ * > KernelFunction func =
233
+ * KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func);
234
+ */
235
+ template <bool AllowLegacyTypes = false, class FuncType>
236
+ static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func);
237
+
238
+ static KernelFunction makeFallthrough();
239
+ static KernelFunction makeAmbiguousAutogradOther();
240
+ static KernelFunction makeNamedNotSupported();
241
+
242
+ /**
243
+ * Create a KernelFunction from an unboxed lambda.
244
+ *
245
+ * Example:
246
+ *
247
+ * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
248
+ * > [] (Tensor a, bool b) -> Tensor {...});
249
+ */
250
+ template <bool AllowLegacyTypes = false, class Lambda>
251
+ static std::enable_if_t<
252
+ guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
253
+ KernelFunction>
254
+ makeFromUnboxedLambda(Lambda&& lambda);
255
+ template <bool AllowLegacyTypes = false, class Lambda>
256
+ static std::enable_if_t<
257
+ !guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
258
+ KernelFunction>
259
+ makeFromUnboxedLambda(Lambda&& lambda);
260
+
261
+ std::string dumpState() const;
262
+ // For testing internal invariants only
263
+ bool _equalsBoxedAndUnboxed(const KernelFunction&) const;
264
+
265
+ private:
266
+ explicit KernelFunction(
267
+ std::unique_ptr<OperatorKernel> functor,
268
+ InternalBoxedKernelFunction* boxed_kernel_func,
269
+ void* unboxed_kernel_func,
270
+ void* sym_unboxed_kernel_func);
271
+ explicit KernelFunction(
272
+ BoxedKernel boxed_fn,
273
+ void* unboxed_kernel_func,
274
+ void* sym_unboxed_kernel_func);
275
+
276
+ BoxedKernel boxed_kernel_func_;
277
+ void* unboxed_kernel_func_;
278
+ void* sym_unboxed_kernel_func_;
279
+ };
280
+
281
+ } // namespace c10
282
+
283
+ #include <ATen/core/boxing/KernelFunction_impl.h>
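Closing the loop on the doc comments above, a sketch of wrapping a stateless unboxed lambda; the resulting KernelFunction can then be invoked either boxed (callBoxed) or unboxed (call<Return, Args...>), with the dispatcher supplying the OperatorHandle and DispatchKeySet:

```
#include <ATen/core/Tensor.h>
#include <ATen/core/boxing/KernelFunction.h>

c10::KernelFunction make_negate_kernel() {
  // The lambda captures nothing, so the stateless-lambda overload of
  // makeFromUnboxedLambda is selected.
  return c10::KernelFunction::makeFromUnboxedLambda(
      [](at::Tensor a, bool flip) -> at::Tensor { return flip ? -a : a; });
}
```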
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h ADDED
@@ -0,0 +1,320 @@
1
+ #include <ATen/core/boxing/impl/WrapFunctionIntoFunctor.h>
2
+ #include <ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h>
3
+ #include <ATen/core/boxing/impl/boxing.h>
4
+ #include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h>
5
+
6
+ #include <c10/util/C++17.h>
7
+ #include <type_traits>
8
+
9
+ namespace c10 {
10
+
11
+ namespace detail {
12
+ template <typename Base, typename Child, typename... Args>
13
+ std::enable_if_t<
14
+ !std::is_array_v<Base> && !std::is_array_v<Child> &&
15
+ std::is_base_of_v<Base, Child>,
16
+ std::unique_ptr<Base>>
17
+ make_unique_base(Args&&... args) {
18
+ return std::unique_ptr<Base>(new Child(std::forward<Args>(args)...));
19
+ }
20
+ } // namespace detail
21
+
22
+ inline KernelFunction::KernelFunction()
23
+ : boxed_kernel_func_(),
24
+ unboxed_kernel_func_(nullptr),
25
+ sym_unboxed_kernel_func_(nullptr) {}
26
+
27
+ inline KernelFunction::KernelFunction(
28
+ std::unique_ptr<OperatorKernel> functor,
29
+ InternalBoxedKernelFunction* boxed_kernel_func,
30
+ void* unboxed_kernel_func,
31
+ void* sym_unboxed_kernel_func = nullptr)
32
+ : boxed_kernel_func_(std::move(functor), boxed_kernel_func),
33
+ unboxed_kernel_func_(unboxed_kernel_func),
34
+ sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {}
35
+
36
+ inline KernelFunction::KernelFunction(
37
+ BoxedKernel boxed_fn,
38
+ void* unboxed_kernel_func,
39
+ void* sym_unboxed_kernel_func = nullptr)
40
+ : boxed_kernel_func_(std::move(boxed_fn)),
41
+ unboxed_kernel_func_(unboxed_kernel_func),
42
+ sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {}
43
+
44
+ inline bool KernelFunction::isValidUnboxed() const {
45
+ return unboxed_kernel_func_ != nullptr;
46
+ }
47
+
48
+ inline bool KernelFunction::isValidSymUnboxed() const {
49
+ return sym_unboxed_kernel_func_ != nullptr;
50
+ }
51
+
52
+ inline bool KernelFunction::isValid() const {
53
+ return boxed_kernel_func_.isValid();
54
+ }
55
+
56
+ inline bool KernelFunction::isFallthrough() const {
57
+ return boxed_kernel_func_.isFallthrough();
58
+ }
59
+
60
+ inline void KernelFunction::callBoxed(
61
+ const OperatorHandle& opHandle,
62
+ DispatchKeySet dispatchKeySet,
63
+ Stack* stack) const {
64
+ boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack);
65
+ }
66
+
67
+ template <class Return, class... Args>
68
+ inline Return callUnboxedKernelFunction(
69
+ void* unboxed_kernel_func,
70
+ OperatorKernel* functor,
71
+ DispatchKeySet dispatchKeySet,
72
+ Args&&... args) {
73
+ using ActualSignature = Return(OperatorKernel*, DispatchKeySet, Args...);
74
+ ActualSignature* func =
75
+ reinterpret_cast<ActualSignature*>(unboxed_kernel_func);
76
+ return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
77
+ }
78
+
79
+ // This template requires you to explicitly specify the argument you want to
80
+ // forward; it doesn't work if you try to deduce it
81
+ // NB: keep this in sync with cloneWithRealTypes in function_schema.cpp
82
+
83
+ template <typename T>
84
+ inline typename remove_symint<T>::type unpackSymInt(T x) {
85
+ return x;
86
+ }
87
+
88
+ template <>
89
+ inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
90
+ return x.guard_int(__FILE__, __LINE__);
91
+ }
92
+
93
+ template <>
94
+ inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(
95
+ c10::SymIntArrayRef x) {
96
+ return C10_AS_INTARRAYREF_SLOW(x);
97
+ }
98
+
99
+ template <>
100
+ inline typename remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(
101
+ std::optional<c10::SymInt> x) {
102
+ return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__))
103
+ : std::nullopt;
104
+ }
105
+
106
+ template <>
107
+ inline typename remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(
108
+ at::OptionalSymIntArrayRef x) {
109
+ return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x))
110
+ : std::nullopt;
111
+ }
112
+
113
+ template <class Return, class... Args>
114
+ C10_ALWAYS_INLINE Return KernelFunction::call(
115
+ const OperatorHandle& opHandle,
116
+ DispatchKeySet dispatchKeySet,
117
+ Args... args) const {
118
+ // note: Args above is intentionally not Args&&. We don't want perfect
119
+ // forwarding, which would require Args to be deduced, but instead we
120
+ // want callers to explicitly specify the Args.
121
+
122
+ if constexpr (std::disjunction_v<has_symint<Args>...>) {
123
+ if (sym_unboxed_kernel_func_ != nullptr) {
124
+ auto* functor = boxed_kernel_func_.getFunctor();
125
+ return callUnboxedKernelFunction<Return, Args...>(
126
+ sym_unboxed_kernel_func_,
127
+ functor,
128
+ dispatchKeySet,
129
+ std::forward<Args>(args)...);
130
+ }
131
+
132
+ if (unboxed_kernel_func_ != nullptr) {
133
+ auto* functor = boxed_kernel_func_.getFunctor();
134
+ return callUnboxedKernelFunction<
135
+ Return,
136
+ typename remove_symint<Args>::type...>(
137
+ unboxed_kernel_func_,
138
+ functor,
139
+ dispatchKeySet,
140
+ unpackSymInt<Args>(args)...);
141
+ }
142
+ } else {
143
+ if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
144
+ auto* functor = boxed_kernel_func_.getFunctor();
145
+ return callUnboxedKernelFunction<Return, Args...>(
146
+ unboxed_kernel_func_,
147
+ functor,
148
+ dispatchKeySet,
149
+ std::forward<Args>(args)...);
150
+ }
151
+ }
152
+
153
+ return impl::BoxedKernelWrapper<Return(Args...)>::call(
154
+ boxed_kernel_func_,
155
+ opHandle,
156
+ dispatchKeySet,
157
+ std::forward<Args>(args)...);
158
+ }
159
+
160
+ inline KernelFunction KernelFunction::makeFromBoxedKernel(
161
+ BoxedKernel boxed_fn) {
162
+ return KernelFunction(
163
+ std::move(boxed_fn), nullptr); // no unboxed function pointer
164
+ }
165
+
166
+ template <KernelFunction::BoxedKernelFunction* func>
167
+ inline KernelFunction KernelFunction::makeFromBoxedFunction() {
168
+ return KernelFunction::makeFromBoxedKernel(
169
+ BoxedKernel::makeFromFunction<func>());
170
+ }
171
+
172
+ template <KernelFunction::BoxedKernelFunction_withDispatchKeys* func>
173
+ inline KernelFunction KernelFunction::makeFromBoxedFunction() {
174
+ return KernelFunction::makeFromBoxedKernel(
175
+ BoxedKernel::makeFromFunction<func>());
176
+ }
177
+
178
+ inline KernelFunction KernelFunction::makeFallthrough() {
179
+ return KernelFunction::makeFromBoxedKernel(BoxedKernel::makeFallthrough());
180
+ }
181
+
182
+ inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() {
183
+ return KernelFunction::makeFromBoxedKernel(
184
+ BoxedKernel::makeAmbiguousAutogradOther());
185
+ }
186
+
187
+ inline KernelFunction KernelFunction::makeNamedNotSupported() {
188
+ return KernelFunction::makeFromBoxedKernel(
189
+ BoxedKernel::makeNamedNotSupported());
190
+ }
191
+
192
+ template <bool AllowLegacyTypes, class KernelFunctor>
193
+ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(
194
+ std::unique_ptr<OperatorKernel> kernelFunctor) {
195
+ #ifndef NDEBUG
196
+ // This assertion is costly for build time so it's debug-gated.
197
+ static_assert(
198
+ guts::is_functor<KernelFunctor>::value,
199
+ "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor> but the argument is not a functor.");
200
+ #endif
201
+ static_assert(
202
+ std::is_base_of_v<OperatorKernel, KernelFunctor>,
203
+ "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
204
+
205
+ auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call;
206
+ void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn);
207
+ bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value;
208
+ return KernelFunction(
209
+ std::move(kernelFunctor),
210
+ &impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::
211
+ call,
212
+ is_symint ? nullptr : void_unboxed_fn,
213
+ is_symint ? void_unboxed_fn : nullptr);
214
+ }
215
+
216
+ template <class KernelFunctor>
217
+ inline KernelFunction KernelFunction::makeFromBoxedFunctor(
218
+ std::unique_ptr<KernelFunctor> kernelFunctor) {
219
+ return KernelFunction::makeFromBoxedKernel(
220
+ BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
221
+ }
222
+
223
+ template <class FuncPtr, bool AllowLegacyTypes>
224
+ inline KernelFunction KernelFunction::makeFromUnboxedFunction(
225
+ FuncPtr func_ptr) {
226
+ static_assert(
227
+ is_compile_time_function_pointer<FuncPtr>::value,
228
+ "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
229
+ static_assert(
230
+ !std::is_same_v<typename FuncPtr::FuncType, BoxedKernelFunction>,
231
+ "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
232
+ #if defined(__GNUC__) && defined(__SANITIZE_ADDRESS__) && !defined(__CUDACC__)
233
+ TORCH_INTERNAL_ASSERT(
234
+ FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
235
+ #else
236
+ static_assert(
237
+ FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
238
+ #endif
239
+
240
+ #if !defined(C10_MOBILE)
241
+ (void)func_ptr; // Suppress unused variable warning
242
+ return makeFromUnboxedFunctor<
243
+ AllowLegacyTypes,
244
+ typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>(
245
+ detail::make_unique_base<
246
+ OperatorKernel,
247
+ typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>());
248
+ #else
249
+ // On mobile, we would rather optimize for binary size than for performance,
250
+ // so let's not inline the kernel into the wrapper but use
251
+ // makeFromUnboxedRuntimeFunction instead.
252
+ return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr());
253
+ #endif
254
+ }
255
+
256
+ template <bool AllowLegacyTypes, class FuncType>
257
+ inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(
258
+ FuncType* func) {
259
+ static_assert(
260
+ guts::is_function_type<FuncType>::value,
261
+ "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
262
+ static_assert(
263
+ !std::is_same_v<FuncType, BoxedKernelFunction>,
264
+ "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
265
+ TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr");
266
+
267
+ return makeFromUnboxedFunctor<
268
+ AllowLegacyTypes,
269
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(
270
+ detail::make_unique_base<
271
+ OperatorKernel,
272
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(func));
273
+ }
274
+
275
+ template <bool AllowLegacyTypes, class Lambda>
276
+ inline std::enable_if_t<
277
+ guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
278
+ KernelFunction>
279
+ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
280
+ static_assert(
281
+ guts::is_functor<std::decay_t<Lambda>>::value,
282
+ "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
283
+
284
+ #if !defined(C10_MOBILE)
285
+ return makeFromUnboxedFunctor<
286
+ AllowLegacyTypes,
287
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
288
+ detail::make_unique_base<
289
+ OperatorKernel,
290
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
291
+ std::forward<Lambda>(lambda)));
292
+ #else
293
+ // On mobile, we would rather optimize for binary size than for performance,
294
+ // so let's not inline the kernel into the wrapper but use
295
+ // makeFromUnboxedRuntimeFunction instead.
296
+ using FuncType =
297
+ typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type;
298
+ return makeFromUnboxedRuntimeFunction<AllowLegacyTypes, FuncType>(lambda);
299
+ #endif
300
+ }
301
+
302
+ template <bool AllowLegacyTypes, class Lambda>
303
+ inline std::enable_if_t<
304
+ !guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
305
+ KernelFunction>
306
+ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
307
+ static_assert(
308
+ guts::is_functor<std::decay_t<Lambda>>::value,
309
+ "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
310
+
311
+ return makeFromUnboxedFunctor<
312
+ AllowLegacyTypes,
313
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
314
+ detail::make_unique_base<
315
+ OperatorKernel,
316
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
317
+ std::forward<Lambda>(lambda)));
318
+ }
319
+
320
+ } // namespace c10
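As a brief aside (not part of the vendored file): the unpackSymInt specializations above are what allow KernelFunction::call to fall back from a SymInt-typed signature to a plain-integer kernel when no sym_unboxed_kernel_func_ is registered. A tiny sketch, assuming only the declarations pulled in by KernelFunction.h and using made-up values:

#include <ATen/core/boxing/KernelFunction.h>

namespace {

void unpack_sym_int_sketch() {
  // A concrete SymInt guards to a plain int64_t so a non-SymInt kernel
  // can be called with it.
  c10::SymInt s{5};
  int64_t plain = c10::unpackSymInt<c10::SymInt>(s);
  (void)plain;

  // Arguments without SymInt in their type pass through unchanged.
  double d = c10::unpackSymInt<double>(3.0);
  (void)d;
}

} // namespace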
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h ADDED
@@ -0,0 +1,27 @@
1
+ #pragma once
2
+ #include <c10/util/intrusive_ptr.h>
3
+
4
+ namespace c10 {
5
+
6
+ /**
7
+ * Inherit from OperatorKernel to implement a c10 kernel.
8
+ *
9
+ * Example:
10
+ * > namespace {
11
+ * > class my_kernel_cpu final : public c10::OperatorKernel {
12
+ * > public:
13
+ * > Tensor operator()(Tensor a, Tensor b) {...}
14
+ * > };
15
+ * > }
16
+ *
17
+ * The kernel class is allowed to have members but these are equivalent
18
+ * to global variables. The kernel implementation is responsible for
19
+ * preventing race conditions on them.
20
+ *
21
+ * See below for how to register this kernel with PyTorch.
22
+ */
23
+ struct TORCH_API OperatorKernel : public c10::intrusive_ptr_target {
24
+ ~OperatorKernel() override = default;
25
+ };
26
+
27
+ } // namespace c10
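For illustration only (not part of the header): a stateful kernel functor in the style of the doc comment above, turned into a KernelFunction via makeFromUnboxedFunctor. `scale_kernel` and `make_scale_kernel` are hypothetical names.

#include <ATen/core/Tensor.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <memory>

namespace {

class scale_kernel final : public c10::OperatorKernel {
 public:
  explicit scale_kernel(double factor) : factor_(factor) {}
  at::Tensor operator()(const at::Tensor& a) {
    return a.mul(factor_);
  }

 private:
  double factor_; // shared across calls, so treat it like a global variable
};

c10::KernelFunction make_scale_kernel() {
  return c10::KernelFunction::makeFromUnboxedFunctor<false, scale_kernel>(
      std::make_unique<scale_kernel>(2.0));
}

} // namespace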
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h ADDED
@@ -0,0 +1,38 @@
1
+ #pragma once
2
+
3
+ #include <c10/core/CompileTimeFunctionPointer.h>
4
+
5
+ namespace c10::impl {
6
+ namespace detail {
7
+ template <class FuncPtr, class ReturnType, class ParameterList>
8
+ class WrapFunctionIntoFunctor_ {};
9
+ template <class FuncPtr, class ReturnType, class... Parameters>
10
+ class WrapFunctionIntoFunctor_<
11
+ FuncPtr,
12
+ ReturnType,
13
+ guts::typelist::typelist<Parameters...>>
14
+ final : public c10::OperatorKernel {
15
+ public:
16
+ C10_ALWAYS_INLINE decltype(auto) operator()(Parameters... args) {
17
+ return (*FuncPtr::func_ptr())(std::forward<Parameters>(args)...);
18
+ }
19
+ };
20
+ } // namespace detail
21
+
22
+ // WrapFunctionIntoFunctor: Wraps a compile time function pointer into a kernel
23
+ // functor. Since it is a compile time function pointer, many compilers can
24
+ // inline it into the wrapper and you don't get any performance overhead for
25
+ // wrapping.
26
+ template <class FuncPtr>
27
+ struct WrapFunctionIntoFunctor final {
28
+ static_assert(
29
+ c10::is_compile_time_function_pointer<FuncPtr>::value,
30
+ "WrapFunctionIntoFunctor can only wrap functions created with TORCH_FN.");
31
+ using type = detail::WrapFunctionIntoFunctor_<
32
+ FuncPtr,
33
+ typename guts::function_traits<typename FuncPtr::FuncType>::return_type,
34
+ typename guts::function_traits<
35
+ typename FuncPtr::FuncType>::parameter_types>;
36
+ };
37
+
38
+ } // namespace c10::impl
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h ADDED
@@ -0,0 +1,41 @@
1
+ #pragma once
2
+
3
+ #include <c10/util/TypeTraits.h>
4
+
5
+ namespace c10::impl {
6
+
7
+ namespace detail {
8
+ template <class FuncType, class ReturnType, class ParameterList>
9
+ class WrapFunctionIntoRuntimeFunctor_ {};
10
+ template <class FuncType, class ReturnType, class... Parameters>
11
+ class WrapFunctionIntoRuntimeFunctor_<
12
+ FuncType,
13
+ ReturnType,
14
+ guts::typelist::typelist<Parameters...>>
15
+ final : public c10::OperatorKernel {
16
+ public:
17
+ template <class FuncType_>
18
+ explicit WrapFunctionIntoRuntimeFunctor_(FuncType_&& kernel_func)
19
+ : kernel_func_(std::forward<FuncType_>(kernel_func)) {}
20
+
21
+ decltype(auto) operator()(Parameters... args) {
22
+ return kernel_func_(std::forward<Parameters>(args)...);
23
+ }
24
+
25
+ private:
26
+ FuncType kernel_func_;
27
+ };
28
+ } // namespace detail
29
+
30
+ // WrapFunctionIntoRuntimeFunctor: Wraps any runtime functor into a functor that
31
+ // inherits from c10::OperatorKernel, so it can be used as a c10 kernel.
32
+ // This can, for example, be used for lambdas, functors or even function
33
+ // pointers. In the case of function pointers, since it is a runtime function
34
+ // pointer, there is an overhead for calling it whenever the kernel is invoked.
35
+ template <class FuncType>
36
+ using WrapFunctionIntoRuntimeFunctor = detail::WrapFunctionIntoRuntimeFunctor_<
37
+ FuncType,
38
+ typename guts::infer_function_traits_t<FuncType>::return_type,
39
+ typename guts::infer_function_traits_t<FuncType>::parameter_types>;
40
+
41
+ } // namespace c10::impl
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/boxing.h ADDED
@@ -0,0 +1,410 @@
1
+ #pragma once
2
+
3
+ // This file contains boxing (not unboxing) logic,
4
+ // i.e. how to make a vector<IValue> from a set of concrete arguments.
5
+
6
+ #include <ATen/core/ivalue.h>
7
+ #include <ATen/core/stack.h>
8
+ #include <c10/core/TensorOptions.h>
9
+
10
+ #include <ATen/core/boxing/BoxedKernel.h>
11
+
12
+ #include <c10/util/Metaprogramming.h>
13
+ #include <type_traits>
14
+
15
+ namespace c10::impl {
16
+
17
+ //
18
+ // utils
19
+ //
20
+
21
+ // is_mutable_tensor_ref
22
+ template <class T>
23
+ struct is_mutable_tensor_ref : std::false_type {};
24
+ template <>
25
+ struct is_mutable_tensor_ref<at::Tensor&> : std::true_type {};
26
+
27
+ // is_tuple_of_mutable_tensor_refs
28
+ //
29
+ template <class T, class Enable = void>
30
+ struct is_tuple_of_mutable_tensor_refs : std::false_type {};
31
+
32
+ template <class T>
33
+ struct is_tuple_of_mutable_tensor_refs<
34
+ T,
35
+ std::enable_if_t<guts::is_instantiation_of<std::tuple, T>::value, void>>
36
+ : guts::typelist::
37
+ all<is_mutable_tensor_ref, guts::typelist::from_tuple_t<T>> {};
38
+
39
+ // has_ivalue_to<T> tests the presence/absence of instance method
40
+ // IValue::to<T>()
41
+ //
42
+ template <class T, class Enable = void>
43
+ struct has_ivalue_to : std::false_type {};
44
+
45
+ template <class T>
46
+ struct ivalue_to_helper {
47
+ using type = decltype(std::declval<IValue>().template to<T>());
48
+ };
49
+ template <class T>
50
+ using ivalue_to_helper_t = typename ivalue_to_helper<T>::type;
51
+
52
+ template <class T>
53
+ struct has_ivalue_to<T, std::void_t<ivalue_to_helper_t<T>>> : std::true_type {};
54
+
55
+ //
56
+ // boxing predicates
57
+ //
58
+
59
+ // A boxable arg type is one that IValue has a constructor for.
60
+ template <typename T>
61
+ using can_box = std::disjunction<
62
+ std::is_constructible<IValue, std::decay_t<T>>,
63
+ // TensorOptions are not directly constructible into IValue,
64
+ // but torch::jit::push knows how to handle them
65
+ std::is_same<TensorOptions, std::decay_t<T>>>;
66
+
67
+ template <typename... Ts>
68
+ using can_box_all = std::conjunction<can_box<Ts>...>;
69
+
70
+ // an unboxable result is one that can be extracted from an IValue
71
+ template <typename T>
72
+ using can_unbox = std::conjunction<
73
+ std::disjunction<
74
+ has_ivalue_to<T>,
75
+ // void returns are ok
76
+ std::is_same<void, T>>,
77
+ std::negation<std::is_lvalue_reference<T>>>;
78
+
79
+ //
80
+ // boxArgs - utility for pushing unboxed args onto IValue stack
81
+ //
82
+ template <class... Args>
83
+ torch::jit::Stack boxArgs(Args... args) {
84
+ // TODO Reuse stack vector instead of allocating?
85
+ torch::jit::Stack stack;
86
+ stack.reserve(sizeof...(Args));
87
+ torch::jit::push(stack, std::forward<Args>(args)...);
88
+ return stack;
89
+ }
90
+
91
+ template <class T>
92
+ inline constexpr size_t boxed_size_one() {
93
+ static_assert(
94
+ !std::is_same_v<std::decay_t<T>, c10::TensorOptions>,
95
+ "need to patch this path to support TensorOptions passed by reference");
96
+ return 1;
97
+ }
98
+
99
+ // torch::jit::push pushes 4 values for a TensorOptions; this needs to
100
+ // be kept in sync.
101
+ template <>
102
+ inline constexpr size_t boxed_size_one<c10::TensorOptions>() {
103
+ return 4;
104
+ }
105
+
106
+ // NOTE: this could probably be simplified with C++17 fold expressions.
107
+ template <typename...>
108
+ struct BoxedSize : std::integral_constant<size_t, 0> {};
109
+ template <class T, class... Args>
110
+ struct BoxedSize<T, Args...>
111
+ : std::integral_constant<
112
+ size_t,
113
+ boxed_size_one<T>() + BoxedSize<Args...>::value> {};
114
+
115
+ template <class... Args>
116
+ static inline constexpr size_t boxed_size() {
117
+ return BoxedSize<Args...>::value;
118
+ }
119
+
120
+ template <typename T>
121
+ C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValue*& dest, T& arg) {
122
+ new (dest++) IValue(arg);
123
+ }
124
+
125
+ C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(
126
+ IValue*& dest,
127
+ c10::TensorOptions options) {
128
+ new (dest++) IValue(c10::typeMetaToScalarType(options.dtype()));
129
+ new (dest++) IValue(options.layout());
130
+ new (dest++) IValue(options.device());
131
+ new (dest++) IValue(options.pinned_memory());
132
+ }
133
+
134
+ inline void boxArgsToStack(IValue*&) {}
135
+
136
+ template <typename T, typename... Args>
137
+ C10_ALWAYS_INLINE_UNLESS_MOBILE void boxArgsToStack(
138
+ IValue*& dest,
139
+ T& arg,
140
+ Args&... args) {
141
+ boxToStack(dest, arg);
142
+ boxArgsToStack(dest, args...);
143
+ }
144
+
145
+ //
146
+ // PopResult is a helper class whose specializations handle popping single and
147
+ // multiple return values, respectively.
148
+ //
149
+ template <class Result>
150
+ struct PopResult final {
151
+ static Result call(Stack& stack) {
152
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
153
+ stack.size() == 1,
154
+ "Boxed kernel was expected to return one value on the stack, ",
155
+ "but instead pushed ",
156
+ stack.size(),
157
+ " values.");
158
+ return std::move(stack[0]).to<Result>();
159
+ }
160
+ };
161
+
162
+ template <class... Types>
163
+ struct PopResult<std::tuple<Types...>> final {
164
+ using Result = std::tuple<Types...>;
165
+
166
+ static Result call(Stack& stack) {
167
+ // for tuple return types, boxed kernel has pushed multiple values onto the
168
+ // stack
169
+ constexpr int RetCount = sizeof...(Types);
170
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
171
+ stack.size() == RetCount,
172
+ "Boxed kernel was expected to return ",
173
+ RetCount,
174
+ " values on the stack, ",
175
+ "but instead pushed ",
176
+ stack.size(),
177
+ " values.");
178
+ return pop_to_tuple_impl(stack, std::make_index_sequence<RetCount>());
179
+ }
180
+
181
+ private:
182
+ // note: this has been moved into its own helper only to avoid a parse error
183
+ // on `indices` otherwise. I'm sure there's an incantation that slips it past
184
+ // the parser but eh
185
+ template <size_t... indices>
186
+ static Result pop_to_tuple_impl(
187
+ Stack& stack,
188
+ std::index_sequence<indices...>) {
189
+ return std::make_tuple((std::move(stack[indices]).template to<Types>())...);
190
+ }
191
+ };
192
+
193
+ //
194
+ // BoxedKernelWrapper
195
+ //
196
+ // For a given function type FT, BoxedKernelWrapper<FT> implements
197
+ // a `call` method that
198
+ // - takes a boxed kernel and unboxed arguments as specified by FT,
199
+ // - calls `boxArgs` to box the arguments
200
+ // - calls the boxed kernel
201
+ // - unboxes and returns the result
202
+ //
203
+ // The partial specializations below handle various cases: in
204
+ // particular, not all types appearing in op signatures are supported,
205
+ // and ops returning references have nonstandard wrapper implementations.
206
+ //
207
+
208
+ // 1. The base specialization of BoxedKernelWrapper should never be
209
+ // instantiated. A "no call method defined on BoxedKernelWrapper" compile error
210
+ // means that an op signature has failed to trigger any of the partial
211
+ // specializations that follow this one.
212
+ //
213
+ template <class FuncType, class Enable = void>
214
+ struct BoxedKernelWrapper {
215
+ // The reason we're not just doing straight up static_assert(false, ...) here:
216
+ // Basically, the way to make sure a static_assert only fires if a template
217
+ // is actually instantiated (rather than every time the file is parsed) is to
218
+ // use template parameters in the expression, e.g. FuncType here. However,
219
+ // since `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the
220
+ // same effect.
221
+ static_assert(
222
+ sizeof(FuncType) != sizeof(FuncType),
223
+ "Function signature contains one or more unsupported parameter and/or return types. "
224
+ "Look for a nearby error like "
225
+ "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" "
226
+ "- (your function type) is the unsupported signature.");
227
+ };
228
+
229
+ //
230
+ // 2. Supported signatures, other than those involving non-const Tensor refs -
231
+ // i.e., "functional" ops.
232
+ //
233
+
234
+ template <class Result, class... Args>
235
+ struct BoxedKernelWrapper<
236
+ Result(Args...),
237
+ std::enable_if_t<
238
+ can_box_all<Args...>::value && can_unbox<Result>::value &&
239
+ !is_tuple_of_mutable_tensor_refs<Result>::value,
240
+ void>> {
241
+ static Result call(
242
+ const BoxedKernel& boxed_kernel_func,
243
+ const OperatorHandle& opHandle,
244
+ DispatchKeySet dispatchKeySet,
245
+ Args... args) {
246
+ torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
247
+ boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
248
+
249
+ if constexpr (!std::is_same_v<void, Result>) {
250
+ // op has pushed one or more values onto the stack.
251
+ return PopResult<Result>::call(stack);
252
+ } else {
253
+ // op returns void, boxed kernel has pushed nothing onto stack.
254
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
255
+ stack.empty(),
256
+ "Boxed kernel was expected to return no values on the stack, ",
257
+ "but instead returned ",
258
+ stack.size(),
259
+ " values.");
260
+ }
261
+ }
262
+ };
263
+
264
+ //
265
+ // 3. in-place ops take a single non-const Tensor reference
266
+ // as their first argument, and return it.
267
+ //
268
+ // Note: all signatures matching this pattern are assumed to be for such ops.
269
+ // Because of this, the generated BoxedKernelWrapper specializations simply
270
+ // return the in-place argument.
271
+ //
272
+
273
+ template <class... OtherArgs>
274
+ struct BoxedKernelWrapper<
275
+ at::Tensor&(at::Tensor&, OtherArgs...),
276
+ std::enable_if_t<can_box_all<OtherArgs...>::value, void>> {
277
+ static at::Tensor& call(
278
+ const BoxedKernel& boxed_kernel_func,
279
+ const OperatorHandle& opHandle,
280
+ DispatchKeySet dispatchKeySet,
281
+ at::Tensor& outArg,
282
+ OtherArgs... otherArgs) {
283
+ torch::jit::Stack stack = boxArgs<at::Tensor&, OtherArgs...>(
284
+ outArg, std::forward<OtherArgs>(otherArgs)...);
285
+ boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
286
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
287
+ stack.size() == 1,
288
+ "Boxed kernel was expected to return a single value on the stack, ",
289
+ "but instead returned ",
290
+ stack.size(),
291
+ " values.");
292
+
293
+ return outArg;
294
+ }
295
+ };
296
+
297
+ //
298
+ // 3.5. In-process migration to make in-place ops take and return
299
+ // const references instead.
300
+ template <class... OtherArgs>
301
+ struct BoxedKernelWrapper<
302
+ const at::Tensor&(const at::Tensor&, OtherArgs...),
303
+ std::enable_if_t<can_box_all<OtherArgs...>::value, void>> {
304
+ static const at::Tensor& call(
305
+ const BoxedKernel& boxed_kernel_func,
306
+ const OperatorHandle& opHandle,
307
+ DispatchKeySet dispatchKeySet,
308
+ const at::Tensor& outArg,
309
+ OtherArgs... otherArgs) {
310
+ torch::jit::Stack stack = boxArgs(outArg, otherArgs...);
311
+ boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
312
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
313
+ stack.size() == 1,
314
+ "Boxed kernel was expected to return a single value on the stack, ",
315
+ "but instead returned ",
316
+ stack.size(),
317
+ " values.");
318
+
319
+ return outArg;
320
+ }
321
+ };
322
+
323
+ //
324
+ // 4. out of place ops that take a single non-const Tensor reference as their
325
+ // final argument, and also return it.
326
+ //
327
+ // Note: all signatures matching this pattern are assumed to be for such ops.
328
+ // This assumption permits the generated BoxedKernelWrapper specializations to
329
+ // simply return out arguments.
330
+ //
331
+ template <class FirstArg, class... RestArgs>
332
+ struct BoxedKernelWrapper<
333
+ at::Tensor&(FirstArg, RestArgs...),
334
+ std::enable_if_t<
335
+ can_box_all<FirstArg, RestArgs...>::value
336
+ // this skips over in-place kernels with a non-const Tensor
337
+ // arg at the front, so those can unambiguously trigger the
338
+ // preceding specialization.
339
+ && !is_mutable_tensor_ref<FirstArg>::value,
340
+ void>> {
341
+ static at::Tensor& call(
342
+ const BoxedKernel& boxed_kernel_func,
343
+ const OperatorHandle& opHandle,
344
+ DispatchKeySet dispatchKeySet,
345
+ FirstArg firstArg,
346
+ RestArgs... restArgs) {
347
+ torch::jit::Stack stack = boxArgs<FirstArg, RestArgs...>(
348
+ std::forward<FirstArg>(firstArg), std::forward<RestArgs>(restArgs)...);
349
+ boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
350
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
351
+ stack.size() == 1,
352
+ "Boxed kernel was expected to return a single value on the stack, ",
353
+ "but instead returned ",
354
+ stack.size(),
355
+ " values.");
356
+
357
+ // reusing restArgs after it has been forwarded here is ok because we know
358
+ // that the last element is of type `Tensor&`.
359
+ return std::get<sizeof...(RestArgs) - 1>(
360
+ std::tuple<RestArgs...>{restArgs...});
361
+ }
362
+ };
363
+
364
+ //
365
+ // 5. out of place ops that take multiple non-const Tensor references as their
366
+ // final arguments, and return them in a std::tuple.
367
+ //
368
+ // Note: all signatures matching this pattern are assumed to be for such ops.
369
+ // This assumption permits the generated BoxedKernelWrapper specializations to
370
+ // simply return the out arguments.
371
+ //
372
+ template <class Result, class... Args>
373
+ struct BoxedKernelWrapper<
374
+ Result(Args...),
375
+ std::enable_if_t<
376
+ can_box_all<Args...>::value &&
377
+ is_tuple_of_mutable_tensor_refs<Result>::value,
378
+ void>> {
379
+ static Result call(
380
+ const BoxedKernel& boxed_kernel_func,
381
+ const OperatorHandle& opHandle,
382
+ DispatchKeySet dispatchKeySet,
383
+ Args... args) {
384
+ using ArgTuple = std::tuple<Args...>;
385
+ constexpr int RetCount = std::tuple_size<Result>();
386
+
387
+ torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
388
+ boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
389
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
390
+ stack.size() == RetCount,
391
+ "Boxed kernel was expected to return ",
392
+ RetCount,
393
+ " values on the stack, ",
394
+ "but instead returned ",
395
+ stack.size(),
396
+ " values.");
397
+
398
+ // reusing args after it has been forwarded here is ok because we know
399
+ // that the last RetCount elements are of type `Tensor&`.
400
+ auto result = guts::tuple_take<ArgTuple, -RetCount>(
401
+ ArgTuple{std::forward<Args>(args)...});
402
+ static_assert(
403
+ std::is_same_v<Result, decltype(result)>,
404
+ "The parameter list of an op returning a tuple of Tensor references "
405
+ "must end with an equal number of Tensor reference parameters.");
406
+ return result;
407
+ }
408
+ };
409
+
410
+ } // namespace c10::impl
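To make the wrapper cases above concrete, here is a small sketch, not part of the vendored header, that replays what the "functional" case (2) does by hand: box the unboxed arguments onto a stack, let a boxed kernel consume it, then pop the result. `boxed_roundtrip_sketch` is a made-up name; no real boxed kernel or OperatorHandle is involved, and the push at the end merely stands in for what a kernel would leave on the stack.

#include <ATen/core/Tensor.h>
#include <ATen/core/boxing/impl/boxing.h>

namespace {

at::Tensor boxed_roundtrip_sketch(const at::Tensor& a, int64_t n) {
  // Box the unboxed arguments, exactly as BoxedKernelWrapper::call does.
  torch::jit::Stack stack =
      c10::impl::boxArgs<const at::Tensor&, int64_t>(a, n);

  // ... a boxed kernel would pop its inputs from `stack` here and push
  // its outputs; for the sketch we pretend it returned `a` unchanged ...
  stack.clear();
  torch::jit::push(stack, a);

  // Unbox the single return value, as the non-void branch of case (2) does.
  return c10::impl::PopResult<at::Tensor>::call(stack);
}

} // namespace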
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h ADDED
@@ -0,0 +1,785 @@
1
+ #pragma once
2
+
3
+ #include <ATen/core/IListRef.h>
4
+ #include <ATen/core/boxing/OperatorKernel.h>
5
+ #include <ATen/core/ivalue.h>
6
+ #include <ATen/core/stack.h>
7
+ #include <c10/util/Metaprogramming.h>
8
+ #include <c10/util/TypeList.h>
9
+ #include <c10/util/intrusive_ptr.h>
10
+
11
+ #include <utility>
12
+
13
+ namespace c10 {
14
+
15
+ using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack
16
+ // to the c10 namespace.
17
+ class OperatorHandle;
18
+
19
+ /*
20
+ * [Note: Argument forwarding in the dispatcher]
21
+ *
22
+ * The dispatcher uses a somewhat unusual way to forward arguments through
23
+ * several layers of wrapper functions. This can be confusing because an
24
+ * experienced C++ programmer would look at this and think "oh this is supposed
25
+ * to be forwarding a universal reference but the && is missing. This is a
26
+ * bug.". It is not a bug. The common way in C++ to forward arguments is to use
27
+ * universal references:
28
+ *
29
+ * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
30
+ *
31
+ * but that relies on inferring the correct reference type (i.e. value vs & vs
32
+ * &&) from the argument. In our case, we cannot rely on the argument as
33
+ * supplied by the caller, because that could infer a different reference type
34
+ * than was used in the kernel function. The correct reference type is dictated
35
+ * by the kernel signature and must be identical since we cast function pointers
36
+ * through void* pointers and mismatches would be UB. So we need a forwarding
37
+ * pattern that determines the reference type to use by looking at the
38
+ * explicitly supplied operator signature, not by looking at the argument we're
39
+ * calling it with.
40
+ *
41
+ * What does std::forward do, exactly?
42
+ * ------------------------------------
43
+ * std::forward<T>(t) is a way to cast t to the reference type supplied in T.
44
+ * Let's assume decay_t<T> == U and T is either U or some reference of U.
45
+ * - std::forward<T&>(t) will return U&, no matter what kind of reference t is.
46
+ * - std::forward<T&&>(t) will return U&&, no matter what kind of reference t
47
+ * is.
48
+ * - std::forward<T>(t) will return U&& (not U!), no matter what kind of
49
+ * reference t is.
50
+ *
51
+ * For universal references, that means that in the following function
52
+ * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
53
+ *
54
+ * - when called with arg being a rvalue reference or non-reference value, T
55
+ * gets inferred to be a non-reference U, and std::forward<T>(t) will return
56
+ * U&&, correctly moving the argument.
57
+ * - when called with arg behind a lvalue reference, T gets inferred to be U&
58
+ * because that's the only way to match the signature (in C++, a type that is
59
+ * (T&)&& will collapse to T&). That means std::forward<T>(t) will return U& and
60
+ * the value will not be moved but passed on as a lvalue reference.
61
+ *
62
+ * How do we use that?
63
+ * ------------------------------------
64
+ * But std::forward can also be used outside of the common "universal
65
+ * forwarding" pattern to change reference types. So instead of following the
66
+ * common C++ pattern, we notice what std::forward<T>() actually does, and that
67
+ * is it takes a value and changes its reference to the type of reference passed
68
+ * in as T. If we don't infer T but explicitly specify it, we can use this to
69
+ * forward based on an explicitly specified reference type instead of the
70
+ * inferred argument type.
71
+ *
72
+ * This is why many of the dispatcher functions look like
73
+ * > template<class T> func(T t) { func2<T>(std::forward<T>(t)); }
74
+ * instead of the common
75
+ * > template<class T> func(T&& t) { func2(std::forward<T>(t)); }
76
+ *
77
+ * and are expected to be called by explicitly specifying the template
78
+ * parameters in a way that matches the expected operator signature at each call
79
+ * site.
80
+ */
81
+
82
+ namespace impl {
83
+ // supported_primitive_arg_types defines which primitive types we allow in
84
+ // kernel functions as arguments or returns.
85
+ // Additionally, we support lists, dicts and optionals containing these types.
86
+ using supported_primitive_arg_types = guts::typelist::typelist<
87
+ int64_t,
88
+ double,
89
+ bool,
90
+ std::string_view,
91
+ at::Tensor,
92
+ at::Scalar,
93
+ c10::QScheme,
94
+ c10::ScalarType,
95
+ c10::Device,
96
+ c10::DeviceIndex,
97
+ c10::Layout,
98
+ c10::MemoryFormat,
99
+ at::Dimname>;
100
+
101
+ // We have an unboxed functor in hand that takes C++ arguments, and
102
+ // we're building a boxed functor wrapper for it that takes IValues.
103
+ // So "outside" is boxed and "inside" is unboxed.
104
+ //
105
+ // So a valid input type is one that our boxed functor wrapper can
106
+ // unbox from an IValue into a C++ value.
107
+ //
108
+ // Whereas a valid output type is one that our wrapper can receive
109
+ // as a C++ value from the unboxed functor, and box into an IValue.
110
+
111
+ //
112
+ // assert_is_valid_input_type
113
+ // checks that T can be unboxed from an IValue into a C++ value.
114
+ //
115
+
116
+ template <class T, bool AllowDeprecatedTypes, class Enable = void>
117
+ struct assert_is_valid_input_type {
118
+ assert_is_valid_input_type() {
119
+ if constexpr (guts::typelist::contains<supported_primitive_arg_types, T>::
120
+ value) {
121
+ /* everything is ok, this is a primitive type */
122
+ } else {
123
+ /* otherwise this must be an instance of a valid custom class, since it
124
+ can only have been created via IValue(x), which ensures this. */
125
+ }
126
+ }
127
+ };
128
+
129
+ template <class T, bool AllowDeprecatedTypes>
130
+ struct assert_is_valid_input_type<std::optional<T>, AllowDeprecatedTypes>
131
+ : assert_is_valid_input_type<T, AllowDeprecatedTypes> {};
132
+
133
+ template <bool AllowDeprecatedTypes, class... Args>
134
+ struct TypeCheckHelper;
135
+
136
+ template <bool AllowDeprecatedTypes>
137
+ struct TypeCheckHelper<AllowDeprecatedTypes> {};
138
+
139
+ template <bool AllowDeprecatedTypes, class Head, class... Rest>
140
+ struct TypeCheckHelper<AllowDeprecatedTypes, Head, Rest...>
141
+ : TypeCheckHelper<AllowDeprecatedTypes, Rest...> {
142
+ assert_is_valid_input_type<Head, AllowDeprecatedTypes> check;
143
+ };
144
+
145
+ template <class... Contained, bool AllowDeprecatedTypes>
146
+ struct assert_is_valid_input_type<
147
+ std::tuple<Contained...>,
148
+ AllowDeprecatedTypes>
149
+ : TypeCheckHelper<AllowDeprecatedTypes, Contained...> {};
150
+
151
+ template <class Key, class Value, bool AllowDeprecatedTypes>
152
+ struct assert_is_valid_input_type<Dict<Key, Value>, AllowDeprecatedTypes>
153
+ : assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
154
+ static_assert(
155
+ guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
156
+ "You tried to register a kernel with an unsupported input type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
157
+ };
158
+
159
+ template <class Key, class Value, bool AllowDeprecatedTypes>
160
+ struct assert_is_valid_input_type<
161
+ std::unordered_map<Key, Value>,
162
+ AllowDeprecatedTypes>
163
+ : assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
164
+ static_assert(
165
+ AllowDeprecatedTypes,
166
+ "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
167
+ static_assert(
168
+ guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
169
+ "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
170
+ };
171
+
172
+ template <class T, bool AllowDeprecatedTypes>
173
+ struct assert_is_valid_input_type<List<T>, AllowDeprecatedTypes>
174
+ : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
175
+ static_assert(
176
+ !std::is_same_v<T, at::Scalar>,
177
+ "You tried to register a kernel with an unsupported input type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
178
+ };
179
+
180
+ template <class T, bool AllowDeprecatedTypes>
181
+ struct assert_is_valid_input_type<c10::ArrayRef<T>, AllowDeprecatedTypes>
182
+ : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
183
+ static_assert(
184
+ !std::is_same_v<T, at::Scalar>,
185
+ "You tried to register a kernel with an unsupported input type: ArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
186
+ };
187
+
188
+ template <class T, bool AllowDeprecatedTypes>
189
+ struct assert_is_valid_input_type<
190
+ c10::OptionalArrayRef<T>,
191
+ AllowDeprecatedTypes>
192
+ : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
193
+ static_assert(
194
+ !std::is_same_v<T, at::Scalar>,
195
+ "You tried to register a kernel with an unsupported input type: OptionalArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
196
+ };
197
+
198
+ template <class T, size_t N, bool AllowDeprecatedTypes>
199
+ struct assert_is_valid_input_type<std::array<T, N>, AllowDeprecatedTypes>
200
+ : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
201
+ static_assert(
202
+ !std::is_same_v<T, at::Scalar>,
203
+ "You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
204
+ };
205
+
206
+ template <class T, bool AllowDeprecatedTypes>
207
+ struct assert_is_valid_input_type<
208
+ T,
209
+ AllowDeprecatedTypes,
210
+ std::enable_if_t<std::is_same_v<float, T>>> {
211
+ // There is no reason to support float when we have double. Keep the API lean.
212
+ static_assert(
213
+ guts::false_t<T>::value,
214
+ "You tried to register a kernel with an unsupported input type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string.");
215
+ };
216
+ template <class T, bool AllowDeprecatedTypes>
217
+ struct assert_is_valid_input_type<
218
+ T,
219
+ AllowDeprecatedTypes,
220
+ std::enable_if_t<std::is_same_v<const char*, T>>> {
221
+ static_assert(
222
+ guts::false_t<T>::value,
223
+ "You tried to register a kernel with an unsupported input type: const char*. Please use std::string_view instead.");
224
+ };
225
+ template <class T, bool AllowDeprecatedTypes>
226
+ struct assert_is_valid_input_type<
227
+ T,
228
+ AllowDeprecatedTypes,
229
+ std::enable_if_t<std::is_same_v<std::vector<bool>, T>>> {
230
+ static_assert(
231
+ guts::false_t<T>::value,
232
+ "You tried to register a kernel with an unsupported input type: vector<bool>. Please use List<bool> instead.");
233
+ };
234
+ template <class T, bool AllowDeprecatedTypes>
235
+ struct assert_is_valid_input_type<
236
+ T,
237
+ AllowDeprecatedTypes,
238
+ std::enable_if_t<
239
+ std::is_integral_v<T> &&
240
+ !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
241
+ static_assert(
242
+ guts::false_t<T>::value,
243
+ "You tried to register a kernel with an unsupported integral input type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string.");
244
+ };
245
+ template <class T, bool AllowDeprecatedTypes>
246
+ struct assert_is_valid_input_type<
247
+ T,
248
+ AllowDeprecatedTypes,
249
+ std::enable_if_t<std::is_same_v<const c10::SymInt&, T>>> {
250
+ static_assert(
251
+ guts::false_t<T>::value,
252
+ "You tried to register a kernel taking c10::SymInt by reference. Please accept it by value instead.");
253
+ };
254
+
255
+ // TODO: it probably would be good to tighten this up quite a bit more with
256
+ // an explicit list for everything
257
+
258
+ //
259
+ // assert_is_valid_output_type
260
+ //
261
+
262
+ template <class T, bool AllowDeprecatedTypes, class Enable = void>
263
+ struct assert_is_valid_output_type {
264
+ assert_is_valid_output_type() {
265
+ if constexpr (guts::typelist::contains<supported_primitive_arg_types, T>::
266
+ value) {
267
+ /* everything is ok, this is a primitive type */
268
+ } else {
269
+ /* otherwise T is verified to be a registered custom class in the IValue
270
+ constructor, so no benefit in double-checking here */
271
+ }
272
+ }
273
+ };
274
+
275
+ template <class T, bool AllowDeprecatedTypes>
276
+ struct assert_is_valid_output_type<std::optional<T>, AllowDeprecatedTypes>
277
+ : assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
278
+
279
+ template <class T, bool AllowDeprecatedTypes>
280
+ struct assert_is_valid_output_type<
281
+ c10::OptionalArrayRef<T>,
282
+ AllowDeprecatedTypes>
283
+ : assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
284
+
285
+ template <class Key, class Value, bool AllowDeprecatedTypes>
286
+ struct assert_is_valid_output_type<Dict<Key, Value>, AllowDeprecatedTypes>
287
+ : assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
288
+ static_assert(
289
+ guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
290
+ "You tried to register a kernel with an unsupported output type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
291
+ static_assert(
292
+ !std::is_same_v<Value, at::Scalar>,
293
+ "You tried to register a kernel with an unsupported output type: Dict<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
294
+ };
295
+
296
+ template <class Key, class Value, bool AllowDeprecatedTypes>
297
+ struct assert_is_valid_output_type<
298
+ std::unordered_map<Key, Value>,
299
+ AllowDeprecatedTypes>
300
+ : assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
301
+ static_assert(
302
+ AllowDeprecatedTypes,
303
+ "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
304
+ static_assert(
305
+ guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
306
+ "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
307
+ static_assert(
308
+ !std::is_same_v<Value, at::Scalar>,
309
+ "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
310
+ };
311
+
312
+ template <class T, bool AllowDeprecatedTypes>
313
+ struct assert_is_valid_output_type<List<T>, AllowDeprecatedTypes>
314
+ : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
315
+ static_assert(
316
+ !std::is_same_v<T, at::Scalar>,
317
+ "You tried to register a kernel with an unsupported output type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
318
+ };
319
+
320
+ template <class T, bool AllowDeprecatedTypes>
321
+ struct assert_is_valid_output_type<std::vector<T>, AllowDeprecatedTypes>
322
+ : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
323
+ static_assert(
324
+ !std::is_same_v<T, at::Scalar>,
325
+ "You tried to register a kernel with an unsupported output type: std::vector<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
326
+ // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel
327
+ // with an unsupported output type: std::vector<T>. Please use List<T>
328
+ // instead.");
329
+ };
330
+
331
+ template <class T, size_t N, bool AllowDeprecatedTypes>
332
+ struct assert_is_valid_output_type<std::array<T, N>, AllowDeprecatedTypes>
333
+ : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
334
+ static_assert(
335
+ !std::is_same_v<T, at::Scalar>,
336
+ "You tried to register a kernel with an unsupported output type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
337
+ };
338
+
339
+ // The following specialisations of assert_is_valid_output_type are technically
340
+ // not necessary since we would hit the base case and show an error message
341
+ // there if they didn't exist, but we can show a better error message
342
+ // in some common error scenarios.
343
+ template <class T, bool AllowDeprecatedTypes>
344
+ struct assert_is_valid_output_type<
345
+ T,
346
+ AllowDeprecatedTypes,
347
+ std::enable_if_t<std::is_same_v<float, T>>> {
348
+ // There is no reason to support float when we have double. Keep the API lean.
349
+ static_assert(
350
+ guts::false_t<T>::value,
351
+ "You tried to register a kernel with an unsupported output type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string.");
352
+ };
353
+ template <class T, bool AllowDeprecatedTypes>
354
+ struct assert_is_valid_output_type<
355
+ T,
356
+ AllowDeprecatedTypes,
357
+ std::enable_if_t<std::is_same_v<const char*, T>>> {
358
+ static_assert(
359
+ guts::false_t<T>::value,
360
+ "You tried to register a kernel with an unsupported output type: const char*. Please use std::string_view instead.");
361
+ };
362
+ template <class T, bool AllowDeprecatedTypes>
363
+ struct assert_is_valid_output_type<
364
+ T,
365
+ AllowDeprecatedTypes,
366
+ std::enable_if_t<std::is_same_v<std::vector<bool>, T>>> {
367
+ static_assert(
368
+ guts::false_t<T>::value,
369
+ "You tried to register a kernel with an unsupported output type: vector<bool>. Please use List<bool> instead.");
370
+ };
371
+ template <class T, bool AllowDeprecatedTypes>
372
+ struct assert_is_valid_output_type<
373
+ T,
374
+ AllowDeprecatedTypes,
375
+ std::enable_if_t<
376
+ std::is_integral_v<T> &&
377
+ !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
378
+ static_assert(
379
+ guts::false_t<T>::value,
380
+ "You tried to register a kernel with an unsupported integral output type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string.");
381
+ };
382
+
383
+ // ivalue_to_arg
384
+
385
+ template <class T>
386
+ struct decay_if_not_tensor final {
387
+ using type = std::decay_t<T>;
388
+ };
389
+
390
+ template <>
391
+ struct decay_if_not_tensor<at::Tensor&> final {
392
+ using type = at::Tensor&;
393
+ };
394
+
395
+ template <>
396
+ struct decay_if_not_tensor<const at::Tensor&> final {
397
+ using type = const at::Tensor&;
398
+ };
399
+
400
+ template <class T, bool AllowDeprecatedTypes>
401
+ struct ivalue_to_arg final {
402
+ static decltype(auto) call(IValue& v) {
403
+ assert_is_valid_input_type<T, AllowDeprecatedTypes>();
404
+ return std::move(v).to<T>();
405
+ }
406
+ };
407
+
408
+ // The following two specializations take advantage of specialized
409
+ // `toTensor()` overloads on IValue to avoid copying.
410
+ template <bool AllowDeprecatedTypes>
411
+ struct ivalue_to_arg<at::Tensor&, AllowDeprecatedTypes> final {
412
+ // We cannot use the default implementation if they asked for a
413
+ // `at::Tensor&` because it moves from the IValue, so it can't get
414
+ // an lvalue reference.
415
+ static at::Tensor& call(IValue& v) {
416
+ // Tensor& is valid, don't bother asserting
417
+ return v.toTensor();
418
+ }
419
+ };
420
+
421
+ template <bool AllowDeprecatedTypes>
422
+ struct ivalue_to_arg<const at::Tensor&, AllowDeprecatedTypes> final {
423
+ // We should not use the default implementation if they asked for
424
+ // a `const at::Tensor&` because it moves from the IValue and they
425
+ // didn't ask for that.
426
+ static const at::Tensor& call(IValue& v) {
427
+ // const Tensor& is valid, don't bother asserting
428
+ return v.toTensor();
429
+ }
430
+ };
431
+
432
+ template <bool AllowDeprecatedTypes>
433
+ struct ivalue_to_arg<at::ITensorListRef, AllowDeprecatedTypes> final {
434
+ static List<at::Tensor> call(IValue& v) {
435
+ return v.toTensorList();
436
+ }
437
+ };
438
+
439
+ template <class T, bool AllowDeprecatedTypes>
440
+ struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final {
441
+ // If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and
442
+ // pass that to the operator. std::vector<T> is implicitly convertible to
443
+ // ArrayRef<T>.
444
+ static std::vector<T> call(IValue& v) {
445
+ return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(v);
446
+ }
447
+ };
448
+ template <bool AllowDeprecatedTypes>
449
+ struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final {
450
+ static std::vector<c10::SymInt> call(IValue& v) {
451
+ if (v.isIntList()) {
452
+ std::vector<c10::SymInt> r;
453
+ auto src = v.toIntList();
454
+ std::transform(
+ src.begin(), src.end(), std::back_inserter(r), [](int64_t i) {
+ return c10::SymInt(i);
+ });
+ return r;
+ } else {
+ return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::
+ call(v);
+ }
+ }
+ };
+ template <bool AllowDeprecatedTypes>
+ struct ivalue_to_arg<c10::OptionalArray<c10::SymInt>, AllowDeprecatedTypes>
+ final {
+ static OptionalArray<c10::SymInt> call(IValue& v) {
+ if (v.isIntList()) {
+ std::vector<c10::SymInt> r;
+ auto src = v.toIntList();
+ std::transform(
+ src.begin(), src.end(), std::back_inserter(r), [](int64_t i) {
+ return c10::SymInt(i);
+ });
+ return OptionalArray<c10::SymInt>(std::move(r));
+ } else {
+ return std::move(v).to<OptionalArray<c10::SymInt>>();
+ }
+ }
+ };
+ template <class T, bool AllowDeprecatedTypes>
+ struct ivalue_to_arg<std::optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
+ // If an argument is std::optional<ArrayRef<T>>, convert the IValue to an
+ // std::optional<std::vector<T>> and pass that to the operator.
+ // OptionalArray<T> is basically a std::optional<std::vector<T>> but
+ // implicitly convertible to std::optional<ArrayRef<T>>.
+ static OptionalArray<T> call(IValue& v) {
+ return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
+ }
+ };
+
+ template <class T, bool AllowDeprecatedTypes>
+ struct ivalue_to_arg<OptionalArrayRef<T>, AllowDeprecatedTypes> final {
+ // If an argument is OptionalArrayRef<T>, convert the IValue to an
+ // std::optional<std::vector<T>> and pass that to the operator.
+ // OptionalArray<T> is basically a std::optional<std::vector<T>> but
+ // implicitly convertible to OptionalArrayRef<T>
+ static OptionalArray<T> call(IValue& v) {
+ return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
+ }
+ };
+
+ // return_to_ivalue
+ template <class T, bool AllowDeprecatedTypes, class Enable = void>
+ struct return_to_ivalue final {};
+
+ template <class T, bool AllowDeprecatedTypes>
+ struct return_to_ivalue<
+ T,
+ AllowDeprecatedTypes,
+ std::enable_if_t<!std::is_same_v<at::Tensor&, T>>>
+ final {
+ static IValue call(T&& v) {
+ assert_is_valid_output_type<T, AllowDeprecatedTypes>();
+ return c10::ivalue::from(std::move(v));
+ }
+ static IValue copy(const T& v) {
+ assert_is_valid_output_type<T, AllowDeprecatedTypes>();
+ return IValue(v);
+ }
+ };
+
+ // Special case to allow kernels to return `Tensor&`.
+ // TODO Delete this once kernels don't do that anymore
+ template <bool AllowDeprecatedTypes>
+ struct return_to_ivalue<at::Tensor&, AllowDeprecatedTypes, void> final {
+ static IValue call(at::Tensor& v) {
+ return c10::ivalue::from(v);
+ }
+ static IValue copy(at::Tensor& v) {
+ return IValue(v);
+ }
+ };
+
+ // wrap_kernel_functor_unboxed_
+
+ template <class KernelFunctor, class OpSignature>
+ struct wrap_kernel_functor_unboxed_ final {};
+
+ // This specialization is for kernels with a first argument that is NOT of type
+ // DispatchKeySet This includes kernels with 0 arguments.
+ template <class KernelFunctor, class ReturnType, class... ParameterTypes>
+ struct wrap_kernel_functor_unboxed_<
+ KernelFunctor,
+ ReturnType(ParameterTypes...)>
+ final {
+ static_assert(
+ std::is_same_v<
+ ReturnType,
+ typename guts::infer_function_traits_t<KernelFunctor>::return_type>,
+ "Return type mismatch");
+ static_assert(
+ std::is_same_v<
+ guts::typelist::typelist<ParameterTypes...>,
+ typename guts::infer_function_traits_t<
+ KernelFunctor>::parameter_types>,
+ "Parameter types mismatch");
+
+ // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes
+ // doesn't use &&
+ static ReturnType call(
+ OperatorKernel* functor,
+ DispatchKeySet,
+ ParameterTypes... args) {
+ KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
+ // Note [Plumbing Keys Through The Dispatcher 2]
+ // See Note [Plumbing Keys Through The Dispatcher] for the background.
+ // This functor explicitly takes in a dispatchKeySet and drops it on the
+ // floor- it does not forward it to the registered kernel.
+ //
+ // This is due to the calling convention within the dispatcher, which
+ // expects all registered kernels to have a first argument of type
+ // DispatchKeySet.
+ // This is not the case for pretty much all manually written kernels,
+ // however- this functor serves to separate the calling convention of the
+ // dispatcher from the calling convention of manually written kernels.
+ return (*functor_)(std::forward<ParameterTypes>(args)...);
+ }
+ };
+
+ // This specialization is for kernels with a first argument of type
+ // DispatchKeySet
+ template <class KernelFunctor, class ReturnType, class... ParameterTypes>
+ struct wrap_kernel_functor_unboxed_<
+ KernelFunctor,
+ ReturnType(DispatchKeySet, ParameterTypes...)>
+ final {
+ static_assert(
+ std::is_same_v<
+ ReturnType,
+ typename guts::infer_function_traits_t<KernelFunctor>::return_type>,
+ "Return type mismatch");
+ static_assert(
+ std::is_same_v<
+ guts::typelist::typelist<DispatchKeySet, ParameterTypes...>,
+ typename guts::infer_function_traits_t<
+ KernelFunctor>::parameter_types>,
+ "Parameter types mismatch");
+
+ // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes
+ // doesn't use &&
+ static ReturnType call(
+ OperatorKernel* functor,
+ DispatchKeySet dispatchKeySet,
+ ParameterTypes... args) {
+ KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
+ // We're explicitly taking in a dispatchKeySet and forwarding it to the
+ // registered kernel. See Note [Plumbing Keys Through The Dispatcher 2] for
+ // details.
+ return (*functor_)(dispatchKeySet, std::forward<ParameterTypes>(args)...);
+ }
+ };
+
+ template <class KernelFunctor>
+ using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_<
+ KernelFunctor,
+ typename guts::infer_function_traits_t<KernelFunctor>::func_type>;
+
+ // call_functor_with_args_from_stack
+
+ template <
+ class Functor,
+ bool AllowDeprecatedTypes,
+ size_t... ivalue_arg_indices,
+ typename... ArgTypes>
+ std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
+ call_functor_with_args_from_stack_(
+ OperatorKernel* functor,
+ DispatchKeySet dispatchKeySet,
+ Stack* stack,
+ std::index_sequence<ivalue_arg_indices...>,
+ guts::typelist::typelist<ArgTypes...>*) {
+ (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
+ // be unused and we have to silence the compiler warning.
+
+ // We're explicitly filtering out DispatchKeySet from the argument list.
+ // Some kernels take a DispatchKeySet as their first argument in order to
+ // plumb keys through the dispatcher. We don't want to expose the
+ // DispatchKeySet type to jit, so we don't include this argument on the stack.
+ // See Note [Plumbing Keys Through The Dispatcher] for the background.
+ return wrap_kernel_functor_unboxed<Functor>::call(
+ functor,
+ dispatchKeySet,
+ ivalue_to_arg<
+ typename decay_if_not_tensor<ArgTypes>::type,
+ AllowDeprecatedTypes>::
+ call(torch::jit::peek(
+ *stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices)))...);
+ }
+
+ template <class Functor, bool AllowDeprecatedTypes>
+ std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
+ call_functor_with_args_from_stack(
+ OperatorKernel* functor,
+ DispatchKeySet dispatchKeySet,
+ Stack* stack) {
+ // We're explicitly filtering out DispatchKeySet from the argument list.
+ // Some kernels take a DispatchKeySet as their first argument in order to
+ // plumb keys through the dispatcher. We don't want to expose the
+ // DispatchKeySet type to jit, so we don't include this argument on the stack.
+ // See Note [Plumbing Keys Through The Dispatcher] for the background.
+ using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<
+ Functor>::parameter_types;
+ constexpr size_t num_ivalue_args = guts::typelist::size<ArgTypes>::value;
+ return call_functor_with_args_from_stack_<Functor, AllowDeprecatedTypes>(
+ functor,
+ dispatchKeySet,
+ stack,
+ std::make_index_sequence<num_ivalue_args>(),
+ static_cast<ArgTypes*>(nullptr));
+ }
+
+ // push_outputs
+
+ template <class OutputType, bool AllowDeprecatedTypes>
+ struct push_outputs final {
+ // Contrary to [Note: Argument forwarding in the dispatcher], we use
+ // OutputType&& here to avoid one extra call to the move constructor in this
+ // case. This is still not a universal reference though because OutputType is
+ // an explicitly specified class template parameter.
+ static void call(OutputType&& output, Stack* stack) {
+ torch::jit::push(
+ *stack,
+ return_to_ivalue<OutputType, AllowDeprecatedTypes>::call(
+ std::forward<OutputType>(output)));
+ }
+ static void copy(const OutputType& output, Stack* stack) {
+ torch::jit::push(
+ *stack,
+ return_to_ivalue<OutputType, AllowDeprecatedTypes>::copy(output));
+ }
+ };
+ template <class... OutputTypes, bool AllowDeprecatedTypes>
+ struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final {
+ static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
+ call_(
+ std::move(output),
+ stack,
+ std::make_index_sequence<sizeof...(OutputTypes)>());
+ }
+ static void copy(const std::tuple<OutputTypes...>& output, Stack* stack) {
+ copy_(output, stack, std::make_index_sequence<sizeof...(OutputTypes)>());
+ }
+
+ private:
+ template <size_t... indices>
+ static void call_(
+ std::tuple<OutputTypes...>&& output,
+ Stack* stack,
+ std::index_sequence<indices...>) {
+ torch::jit::push(
+ *stack,
+ return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::call(
+ std::forward<OutputTypes>(std::get<indices>(output)))...);
+ }
+ template <size_t... indices>
+ static void copy_(
+ const std::tuple<OutputTypes...>& output,
+ Stack* stack,
+ std::index_sequence<indices...>) {
+ torch::jit::push(
+ *stack,
+ return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::copy(
+ std::get<indices>(output))...);
+ }
+ };
+ template <bool AllowDeprecatedTypes>
+ struct push_outputs<void, AllowDeprecatedTypes> final {
+ static void call(int /*dummy*/, Stack* /*stack*/) {}
+ static void copy(int /*dummy*/, Stack* /*stack*/) {}
+ };
+
+ // make_boxed_from_unboxed_functor
+
+ template <class KernelFunctor, bool AllowDeprecatedTypes>
+ struct make_boxed_from_unboxed_functor final {
+ static_assert(
+ std::is_base_of_v<OperatorKernel, KernelFunctor>,
+ "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
+
+ static void call(
+ OperatorKernel* functor,
+ const OperatorHandle&,
+ DispatchKeySet dispatchKeySet,
+ Stack* stack) {
+ using ReturnType =
+ typename guts::infer_function_traits_t<KernelFunctor>::return_type;
+ // We're explicitly filtering out DispatchKeySet from the argument list.
+ // Some kernels take a DispatchKeySet as their first argument in order to
+ // plumb keys through the dispatcher. We don't want to expose the
+ // DispatchKeySet type to jit, so we don't include this argument on the
+ // stack. See Note [Plumbing Keys Through The Dispatcher] for the
+ // background.
+ using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<
+ KernelFunctor>::parameter_types;
+ constexpr bool has_outputs = !std::is_same_v<void, ReturnType>;
+ constexpr size_t num_inputs = guts::typelist::size<ArgTypes>::value;
+ if constexpr (has_outputs) {
+ // Decay ReturnType to ReturnType_ so that if a reference gets returned,
+ // we actually store it by value and don't get a dangling reference. This
+ // is only required because some kernels still return `Tensor&`. [Note:
+ // VC++ and 'std': ambiguous symbol]
+ using ReturnType_ = ::std::decay_t<ReturnType>;
+ ReturnType_ output = call_functor_with_args_from_stack<
+ KernelFunctor,
+ AllowDeprecatedTypes>(functor, dispatchKeySet, stack);
+ torch::jit::drop(*stack, num_inputs);
+ // See note [ VC++ and 'std': ambiguous symbol]
+ push_outputs<ReturnType_, AllowDeprecatedTypes>::call(
+ ::std::move(output), stack);
+ } else {
+ call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(
+ functor, dispatchKeySet, stack);
+ torch::jit::drop(*stack, num_inputs);
+ }
+ }
+ };
+ } // namespace impl
+
+ } // namespace c10
+
+ namespace torch {
+ using OperatorKernel = c10::OperatorKernel;
+ }
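
Side note (illustrative, not part of the uploaded header): the machinery above is what lets a hand-written functor be driven through the boxed, stack-based calling convention. A minimal sketch of such a kernel, registered via the legacy kernel<Functor>() API that the static_assert message refers to; the operator name "myops::my_add" and the functor itself are hypothetical:

#include <ATen/core/Tensor.h>
#include <ATen/core/op_registration/op_registration.h>

namespace {
// An unboxed kernel functor: it must inherit from c10::OperatorKernel.
// make_boxed_from_unboxed_functor::call() pops the IValue arguments off the
// stack via ivalue_to_arg, invokes operator(), and pushes the result back
// via push_outputs.
struct MyAddKernel final : c10::OperatorKernel {
  at::Tensor operator()(const at::Tensor& a, const at::Tensor& b) {
    return a.add(b);
  }
};

// Hypothetical schema; registration through the legacy c10 API.
static auto registry = c10::RegisterOperators().op(
    "myops::my_add(Tensor a, Tensor b) -> Tensor",
    c10::RegisterOperators::options().kernel<MyAddKernel>(
        c10::DispatchKey::CPU));
} // namespace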
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h ADDED
@@ -0,0 +1,140 @@
+ #pragma once
+
+ #include <gmock/gmock.h>
+ #include <gtest/gtest.h>
+
+ #include <ATen/core/Tensor.h>
+ #include <ATen/core/dispatch/Dispatcher.h>
+ #include <ATen/core/ivalue.h>
+ #include <c10/core/CPUAllocator.h>
+ #include <c10/util/irange.h>
+
+ template <class... Inputs>
+ inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
+ return {std::forward<Inputs>(inputs)...};
+ }
+
+ inline at::Tensor dummyTensor(
+ c10::DispatchKeySet ks,
+ bool requires_grad = false) {
+ auto* allocator = c10::GetCPUAllocator();
+ int64_t nelements = 1;
+ auto dtype = caffe2::TypeMeta::Make<float>();
+ int64_t size_bytes = nelements * dtype.itemsize();
+ auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
+ c10::StorageImpl::use_byte_size_t(),
+ size_bytes,
+ allocator->allocate(size_bytes),
+ allocator,
+ /*resizable=*/true);
+ at::Tensor t =
+ at::detail::make_tensor<c10::TensorImpl>(storage_impl, ks, dtype);
+ // TODO: We add this to simulate the ideal case where we only have Autograd
+ // backend keys
+ // on Tensor when it requires grad. But currently Autograd keys are
+ // added in TensorImpl constructor by default.
+ if (!requires_grad) {
+ t.unsafeGetTensorImpl()->remove_autograd_key();
+ }
+ return t;
+ }
+
+ inline at::Tensor dummyTensor(
+ c10::DispatchKey dispatch_key,
+ bool requires_grad = false) {
+ return dummyTensor(c10::DispatchKeySet(dispatch_key), requires_grad);
+ }
+
+ template <class... Args>
+ inline std::vector<c10::IValue> callOp(
+ const c10::OperatorHandle& op,
+ Args... args) {
+ auto stack = makeStack(std::forward<Args>(args)...);
+ op.callBoxed(&stack);
+ return stack;
+ }
+
+ template <class Result, class... Args>
+ inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) {
+ return op.typed<Result(Args...)>().call(std::forward<Args>(args)...);
+ }
+
+ template <class Result, class... Args>
+ inline Result callOpUnboxedWithDispatchKey(
+ const c10::OperatorHandle& op,
+ c10::DispatchKey dispatchKey,
+ Args... args) {
+ return op.typed<Result(Args...)>().callWithDispatchKey(
+ dispatchKey, std::forward<Args>(args)...);
+ }
+
+ template <class Result, class... Args>
+ inline Result callOpUnboxedWithPrecomputedDispatchKeySet(
+ const c10::OperatorHandle& op,
+ c10::DispatchKeySet ks,
+ Args... args) {
+ return op.typed<Result(Args...)>().redispatch(
+ ks, std::forward<Args>(args)...);
+ }
+
+ inline void expectDoesntFindKernel(
+ const char* op_name,
+ c10::DispatchKey dispatch_key) {
+ auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
+ EXPECT_ANY_THROW(callOp(*op, dummyTensor(dispatch_key), 5););
+ }
+
+ inline void expectDoesntFindOperator(const char* op_name) {
+ auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
+ EXPECT_FALSE(op.has_value());
+ }
+
+ template <class Exception, class Functor>
+ inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
+ try {
+ std::forward<Functor>(functor)();
+ } catch (const Exception& e) {
+ EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains));
+ return;
+ }
+ ADD_FAILURE() << "Expected to throw exception containing \""
+ << expectMessageContains << "\" but didn't throw";
+ }
+
+ template <class T, size_t N>
+ void expectListEquals(c10::ArrayRef<T> expected, std::array<T, N> actual) {
+ EXPECT_EQ(expected.size(), actual.size());
+ for (const auto i : c10::irange(expected.size())) {
+ EXPECT_EQ(expected[i], actual[i]);
+ }
+ }
+
+ template <class T>
+ void expectListEquals(c10::ArrayRef<T> expected, c10::ArrayRef<T> actual) {
+ EXPECT_EQ(expected.size(), actual.size());
+ for (const auto i : c10::irange(expected.size())) {
+ EXPECT_EQ(expected[i], actual[i]);
+ }
+ }
+
+ template <class T>
+ void expectListEquals(c10::ArrayRef<T> expected, c10::List<T> actual) {
+ EXPECT_EQ(expected.size(), actual.size());
+ for (const auto i : c10::irange(expected.size())) {
+ EXPECT_EQ(expected[i], actual.get(i));
+ }
+ }
+
+ template <class T>
+ void expectListEquals(c10::ArrayRef<T> expected, std::vector<T> actual) {
+ EXPECT_EQ(expected.size(), actual.size());
+ for (const auto i : c10::irange(expected.size())) {
+ EXPECT_EQ(expected[i], actual[i]);
+ }
+ }
+
+ // NB: This is not really sound, but all of the type sets constructed here
+ // are singletons so it's fine
+ static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) {
+ return legacyExtractDispatchKey(t.key_set());
+ }
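
Illustrative usage (not part of the uploaded header): these helpers are meant to be called from gtest test bodies. A sketch, assuming an operator under the hypothetical name "_test::dummy(Tensor t, int a) -> Tensor" has already been registered with a CPU kernel:

TEST(OperatorRegistrationTest, callsRegisteredBoxedKernel) {
  auto op = c10::Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());
  // callOp packs the arguments into an IValue stack via makeStack, runs the
  // boxed kernel, and hands back the stack now holding the outputs.
  auto outputs = callOp(*op, dummyTensor(c10::DispatchKey::CPU), 4);
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(outputs[0].toTensor()));
}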
phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/CppSignature.h ADDED
@@ -0,0 +1,67 @@
+ #pragma once
+
+ #include <c10/core/DispatchKeySet.h>
+ #include <c10/macros/Macros.h>
+ #include <c10/util/Metaprogramming.h>
+ #include <c10/util/Type.h>
+ #include <typeindex>
+
+ namespace c10::impl {
+
+ // A CppSignature object holds RTTI information about a C++ function signature
+ // at runtime and can compare them or get a debug-printable name.
+ class TORCH_API CppSignature final {
+ public:
+ CppSignature(const CppSignature&) = default;
+ CppSignature(CppSignature&&) noexcept = default;
+ CppSignature& operator=(const CppSignature&) = default;
+ CppSignature& operator=(CppSignature&&) noexcept = default;
+
+ template <class FuncType>
+ static CppSignature make() {
+ // Normalize functors, lambdas, function pointers, etc. into the plain
+ // function type The first argument of the schema might be of type
+ // DispatchKeySet, in which case we remove it. We do this to guarantee that
+ // all CppSignature's for an operator will match, even if they're registered
+ // with different calling conventions.
+ // See Note [Plumbing Keys Through The Dispatcher]
+ using decayed_function_type =
+ typename c10::remove_DispatchKeySet_arg_from_func<
+ std::decay_t<FuncType>>::func_type;
+
+ return CppSignature(std::type_index(typeid(decayed_function_type)));
+ }
+
+ std::string name() const {
+ return c10::demangle(signature_.name());
+ }
+
+ friend bool operator==(const CppSignature& lhs, const CppSignature& rhs) {
+ if (lhs.signature_ == rhs.signature_) {
+ return true;
+ }
+ // Without RTLD_GLOBAL, the type_index comparison could yield false because
+ // they point to different instances of the RTTI data, but the types would
+ // still be the same. Let's check for that case too.
+ // Note that there still is a case where this might not work, i.e. when
+ // linking libraries of different compilers together, they might have
+ // different ways to serialize a type name. That, together with a missing
+ // RTLD_GLOBAL, would still fail this.
+ if (0 == strcmp(lhs.signature_.name(), rhs.signature_.name())) {
+ return true;
+ }
+
+ return false;
+ }
+
+ private:
+ explicit CppSignature(std::type_index signature)
+ : signature_(std::move(signature)) {}
+ std::type_index signature_;
+ };
+
+ inline bool operator!=(const CppSignature& lhs, const CppSignature& rhs) {
+ return !(lhs == rhs);
+ }
+
+ } // namespace c10::impl
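
Illustrative sketch (not part of the uploaded header): because make() strips a leading DispatchKeySet parameter before taking the type_index, two kernels that differ only in that plumbing argument should compare as having the same C++ signature. The kernel names below are hypothetical:

#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/CppSignature.h>

at::Tensor kernel_plain(const at::Tensor& t);
at::Tensor kernel_with_keys(c10::DispatchKeySet ks, const at::Tensor& t);

void signature_check() {
  auto a = c10::impl::CppSignature::make<decltype(kernel_plain)>();
  auto b = c10::impl::CppSignature::make<decltype(kernel_with_keys)>();
  TORCH_INTERNAL_ASSERT(a == b);            // DispatchKeySet argument removed
  TORCH_INTERNAL_ASSERT(!a.name().empty()); // demangled, printable name
}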
phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h ADDED
@@ -0,0 +1,279 @@
+ #pragma once
+
+ #include <ATen/core/Variadic.h>
+ #include <ATen/core/function_schema.h>
+ #include <ATen/core/jit_type.h>
+ #include <ATen/core/stack.h>
+ #include <c10/core/DispatchKeySet.h>
+ #include <c10/util/Bitset.h>
+ #include <c10/util/irange.h>
+ #include <cstdint>
+
+ namespace c10 {
+
+ namespace impl {
+
+ // Take a DispatchKeySet for a Tensor and determine what the actual dispatch
+ // DispatchKey should be, taking into account TLS, and skipping backends which
+ // fall through.
+ //
+ // Unlike Tensor::key_set(), the value of this on a tensor can change depending
+ // on TLS.
+ //
+ // NB: If there is no valid dispatch key, this will return Undefined
+ inline DispatchKeySet computeDispatchKeySet(
+ DispatchKeySet ks,
+ // The key mask lets us eliminate (by zero entries) keys which should not
+ // be considered for dispatch. There are two cases when we use this:
+ //
+ // - If an operator's dispatch table contains a fallthrough entry, we
+ // should bypass it entirely when finding the key
+ // - If a user invokes with redispatch, the mask lets us
+ // zero out the key the user asked us to stop.
+ //
+ // These excluded backends are NOT tracked in the TLS, but must be applied
+ // AFTER TLS (since the backend may have been introduced for consideration
+ // by the included TLS), which is why you have to pass them in to this
+ // function (as opposed to just applying it to the input 'ks').
+ DispatchKeySet key_mask) {
+ c10::impl::LocalDispatchKeySet local =
+ c10::impl::tls_local_dispatch_key_set();
+ // TODO: It's a bit irritating that we have to do logical ORs here, it would
+ // be nice to only do one. Can always_included be folded into the TLS? Well,
+ // it's a bit troublesome, because fastpath TLS access requires the type of
+ // the TLS in question to be zero-initialized, so you don't actually win
+ // anything in that case.
+ return (((ks | local.included_) - local.excluded_) & key_mask);
+ }
+
+ } // namespace impl
+
+ namespace detail {
+ // A small gadget to extract the DispatchKeySet from types which are known
+ // to have it. Used to extract dispatch keys from unboxed calls.
+ struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
+ DispatchKeySet ts;
+ void operator()(const at::Tensor& x) {
+ ts = ts | x.key_set();
+ }
+ void operator()(const std::optional<at::Tensor>& x) {
+ if (x.has_value()) {
+ ts = ts | x->key_set();
+ }
+ }
+ void operator()(at::ArrayRef<at::Tensor> xs) {
+ for (const auto& x : xs) {
+ ts = ts | x.key_set();
+ }
+ }
+ // Tensor?[] translates to this case.
+ void operator()(const c10::List<std::optional<at::Tensor>>& xs) {
+ for (std::optional<at::Tensor> x : xs) {
+ if (x.has_value()) {
+ ts = ts | x.value().key_set();
+ }
+ }
+ }
+ // Structured Tensor[] translates to this case
+ void operator()(const at::ITensorListRef& xs) {
+ for (const auto& x : xs) {
+ ts = ts | x.key_set();
+ }
+ }
+ [[noreturn]] void operator()(at::ArrayRef<std::optional<at::Tensor>>) {
+ // Just checking that the handling of Tensor?[] didn't change.
+ TORCH_INTERNAL_ASSERT(false);
+ }
+ void operator()(const at::Generator& gen) {
+ if (gen.defined()) {
+ ts = ts | gen.key_set();
+ }
+ }
+ void operator()(const std::optional<at::Generator>& gen) {
+ if (gen.has_value() && gen->defined()) {
+ ts = ts | gen->key_set();
+ }
+ }
+ template <typename T>
+ void operator()(const T&) {
+ // do nothing
+ }
+ };
+
+ // NB: take by const reference (Don't do universal forwarding here! You
+ // don't want to move into this function!)
+ template <typename... Args>
+ DispatchKeySet multi_dispatch_key_set(const Args&... args) {
+ return MultiDispatchKeySet().apply(args...).ts;
+ }
+ } // namespace detail
+
+ /**
+ * An instance of DispatchKeyExtractor knows how to get a dispatch key given
+ * a list of arguments for an operator call.
+ *
+ * The instance is specific for a certain operator as:
+ * - In boxed dispatch, different operators have different ways to extract
+ * the dispatch key (e.g. different numbers of arguments), and we precompute
+ * the stack locations we should look at; and
+ * - In all dispatch, some backends should be excluded from dispatch because
+ * they have been registered as fallthrough. The set of excluded backends
+ * varies from operator, as some operators may have overridden the
+ * fallthrough with custom behavior.
+ *
+ * Note - this should maintain identical impl to the py dispatcher key
+ * extraction logic at pytorch/torch/dispatcher.py
+ */
+ struct TORCH_API DispatchKeyExtractor final {
+ public:
+ static DispatchKeyExtractor make(const FunctionSchema& schema) {
+ return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema));
+ }
+
+ static DispatchKeyExtractor makeUninitialized() {
+ return DispatchKeyExtractor(c10::utils::bitset());
+ }
+
+ void registerSchema(const FunctionSchema& schema) {
+ TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset());
+ dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema);
+ }
+ void deregisterSchema() {
+ dispatch_arg_indices_reverse_ = c10::utils::bitset();
+ }
+
+ DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const {
+ DispatchKeySet ks;
+ dispatch_arg_indices_reverse_.for_each_set_bit([&](size_t
+ reverse_arg_index) {
+ const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
+ if (C10_LIKELY(ivalue.isTensor())) {
+ // NB: Take care not to introduce a refcount bump (there's
+ // no safe toTensorRef method, alas)
+ ks = ks | ivalue.unsafeToTensorImpl()->key_set();
+ } else if (C10_UNLIKELY(ivalue.isTensorList())) {
+ // NB: use toListRef as it doesn't induce refcount bumps
+ // (toTensorListRef is not a thing)
+ for (const auto& nv : ivalue.toListRef()) {
+ auto* tensor = nv.unsafeToTensorImpl();
+ ks = ks | tensor->key_set();
+ }
+ }
+ // Tensor?[] translates to a c10::List<IValue> so we need to peek inside
+ else if (C10_UNLIKELY(ivalue.isList())) {
+ for (const auto& elt : ivalue.toListRef()) {
+ if (elt.isTensor()) {
+ ks = ks | elt.toTensor().key_set();
+ }
+ }
+ }
+ });
+ // Keys that are fallthrough should be skipped
+ if (requiresBitsetPerBackend_) {
+ c10::impl::LocalDispatchKeySet tls =
+ c10::impl::tls_local_dispatch_key_set();
+ auto backend_idx =
+ ((ks | tls.included_) - tls.excluded_).getBackendIndex();
+ return impl::computeDispatchKeySet(
+ ks, nonFallthroughKeysPerBackend_[backend_idx]);
+ } else {
+ return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
+ }
+ }
+
+ template <class... Args>
+ DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
+ auto ks = detail::multi_dispatch_key_set(args...);
+ // Keys that are fallthrough should be skipped
+ if (requiresBitsetPerBackend_) {
+ c10::impl::LocalDispatchKeySet tls =
+ c10::impl::tls_local_dispatch_key_set();
+ auto backend_idx =
+ ((ks | tls.included_) - tls.excluded_).getBackendIndex();
+ return impl::computeDispatchKeySet(
+ ks, nonFallthroughKeysPerBackend_[backend_idx]);
+ } else {
+ return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
+ }
+ }
+
+ void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);
+
+ std::string dumpState() const;
+ void checkInvariants(const FunctionSchema& schema) const;
+
+ private:
+ static bool isDispatchType(const Type& type) {
+ // Checking isSubtypeOf on a DynamicType heap-allocates a
+ // DynamicType version of the argument if it's not a DynamicType
+ // already, and this has measurable overhead during startup.
+ #ifdef C10_MOBILE
+ struct CachedTypes {
+ DynamicTypePtr listOfTensors;
+ DynamicTypePtr listOfOptionalTensors;
+ DynamicTypePtr optionalOfTensor;
+ };
+ static const CachedTypes ct = {
+ DynamicType::create(*ListType::ofTensors()),
+ DynamicType::create(*ListType::ofOptionalTensors()),
+ DynamicType::create(*OptionalType::ofTensor())};
+ return type.isSubtypeOf(c10::TypeFactory::get<TensorType>()) ||
+ type.isSubtypeOf(ct.listOfTensors) ||
+ type.isSubtypeOf(ct.listOfOptionalTensors) ||
+ type.isSubtypeOf(ct.optionalOfTensor);
+ #else // C10_MOBILE
+ return type.isSubtypeOf(*TensorType::get()) ||
+ type.isSubtypeOf(*ListType::ofTensors()) ||
+ type.isSubtypeOf(*ListType::ofOptionalTensors()) ||
+ type.isSubtypeOf(*OptionalType::ofTensor());
+ #endif // C10_MOBILE
+ }
+ static c10::utils::bitset makeBitsetForDispatchArgs(
+ const FunctionSchema& schema) {
+ TORCH_CHECK(
+ schema.arguments().size() <= c10::utils::bitset::NUM_BITS(),
+ "The function schema has ",
+ schema.arguments().size(),
+ " arguments but this PyTorch build only supports ",
+ c10::utils::bitset::NUM_BITS());
+ c10::utils::bitset dispatch_arg_indices_reverse;
+ for (const auto index : c10::irange(schema.arguments().size())) {
+ if (isDispatchType(*schema.arguments()[index].type())) {
+ dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
+ }
+ }
+ return dispatch_arg_indices_reverse;
+ }
+
+ explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
+ : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse),
+ nonFallthroughKeys_(DispatchKeySet::FULL) {
+ for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
+ nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
+ }
+ }
+
+ // this is a bitset that has ones for each argument index which has to be
+ // considered for dispatch. This avoids having to iterate over the stack
+ // to find all the tensors. The bits are stored in reverse order, i.e.
+ // dispatch_arg_indices_reverse_[i] == true, then the i-th argument from
+ // the top of the stack (i.e. the i-th last argument of the function)
+ // is relevant for dispatch.
+ // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that just
+ // means you must do the fallthrough
+ c10::utils::bitset dispatch_arg_indices_reverse_;
+
+ // Set of functionality keys for which the operator does NOT have fallthrough
+ // kernel.
+ DispatchKeySet nonFallthroughKeys_;
+ // Set of functionality keys for which the operator does NOT have fallthrough
+ // kernel, defined PER BACKEND. This is only needed if we know that the
+ // operator has a different set of fallthroughs defined for some backends.
+ std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
+ // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast
+ // path), or if we need to fall back to the slower path and check
+ // nonFallthroughKeysPerBackend_
+ bool requiresBitsetPerBackend_{false};
+ };
+
+ } // namespace c10
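
Illustrative sketch (not part of the uploaded header): for an unboxed call, the pieces above combine roughly as follows. The free function and its tensor arguments are hypothetical, and DispatchKeySet::FULL stands in for the operator-specific non-fallthrough mask that a real DispatchKeyExtractor would supply:

#include <ATen/core/dispatch/DispatchKeyExtractor.h>

c10::DispatchKey pick_dispatch_key(const at::Tensor& a, const at::Tensor& b) {
  // Union of the key sets of all dispatch-relevant arguments.
  c10::DispatchKeySet ks = c10::detail::multi_dispatch_key_set(a, b);
  // Fold in thread-local included/excluded keys and mask out fallthroughs.
  c10::DispatchKeySet masked = c10::impl::computeDispatchKeySet(
      ks, c10::DispatchKeySet(c10::DispatchKeySet::FULL));
  return masked.highestPriorityTypeId();
}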