lsmpp committed on
Commit
5fa88dc
·
verified ·
1 Parent(s): 10f1e6a

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenGeneral.h +3 -0
  2. .venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenOpList.h +13 -0
  3. .venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_fwd.h +46 -0
  4. .venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_pch.h +161 -0
  5. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Array.h +48 -0
  6. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Backtrace.h +2 -0
  7. .venv/lib/python3.12/site-packages/torch/include/ATen/core/CachingHostAllocator.h +737 -0
  8. .venv/lib/python3.12/site-packages/torch/include/ATen/core/CheckMemoryFormat.h +24 -0
  9. .venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h +139 -0
  10. .venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h +33 -0
  11. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict.h +396 -0
  12. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict_inl.h +208 -0
  13. .venv/lib/python3.12/site-packages/torch/include/ATen/core/DimVector.h +13 -0
  14. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Dimname.h +48 -0
  15. .venv/lib/python3.12/site-packages/torch/include/ATen/core/DistributionsHelper.h +332 -0
  16. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Formatting.h +25 -0
  17. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Generator.h +191 -0
  18. .venv/lib/python3.12/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h +39 -0
  19. .venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef.h +631 -0
  20. .venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef_inl.h +203 -0
  21. .venv/lib/python3.12/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h +111 -0
  22. .venv/lib/python3.12/site-packages/torch/include/ATen/core/List.h +491 -0
  23. .venv/lib/python3.12/site-packages/torch/include/ATen/core/List_inl.h +353 -0
  24. .venv/lib/python3.12/site-packages/torch/include/ATen/core/MT19937RNGEngine.h +194 -0
  25. .venv/lib/python3.12/site-packages/torch/include/ATen/core/NamedTensor.h +143 -0
  26. .venv/lib/python3.12/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h +187 -0
  27. .venv/lib/python3.12/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h +240 -0
  28. .venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonFallbackKernel.h +35 -0
  29. .venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h +22 -0
  30. .venv/lib/python3.12/site-packages/torch/include/ATen/core/QuantizerBase.h +84 -0
  31. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Range.h +25 -0
  32. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Reduction.h +14 -0
  33. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Scalar.h +1 -0
  34. .venv/lib/python3.12/site-packages/torch/include/ATen/core/ScalarType.h +1 -0
  35. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Tensor.h +98 -0
  36. .venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorAccessor.h +275 -0
  37. .venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBase.h +1056 -0
  38. .venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBody.h +0 -0
  39. .venv/lib/python3.12/site-packages/torch/include/ATen/core/TorchDispatchUtils.h +17 -0
  40. .venv/lib/python3.12/site-packages/torch/include/ATen/core/TransformationHelper.h +175 -0
  41. .venv/lib/python3.12/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h +1 -0
  42. .venv/lib/python3.12/site-packages/torch/include/ATen/core/UnsafeFromTH.h +21 -0
  43. .venv/lib/python3.12/site-packages/torch/include/ATen/core/VariableHooksInterface.h +83 -0
  44. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Variadic.h +92 -0
  45. .venv/lib/python3.12/site-packages/torch/include/ATen/core/Vitals.h +94 -0
  46. .venv/lib/python3.12/site-packages/torch/include/ATen/core/alias_info.h +162 -0
  47. .venv/lib/python3.12/site-packages/torch/include/ATen/core/aten_interned_strings.h +2294 -0
  48. .venv/lib/python3.12/site-packages/torch/include/ATen/core/blob.h +204 -0
  49. .venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h +213 -0
  50. .venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h +106 -0
.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenGeneral.h ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/macros/Macros.h>
.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenOpList.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/macros/Export.h>
4
+
5
+ namespace c10 {
6
+ struct OperatorName;
7
+ }
8
+
9
+ namespace at {
10
+
11
+ // check if an op is a custom op (i.e. did not come from native_functions.yaml)
12
+ TORCH_API bool is_custom_op(const c10::OperatorName& opName);
13
+ }
.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_fwd.h ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/core/QScheme.h>
3
+
4
+ // Forward declarations of core ATen types used in dispatch functions
5
+ namespace c10 {
6
+
7
+ template<typename T>
8
+ class List;
9
+ template<typename T>
10
+ class IListRef;
11
+ class Stream;
12
+ class Scalar;
13
+ class SymInt;
14
+ class SymIntList;
15
+ struct Storage;
16
+ struct TensorOptions;
17
+ template <typename T>
18
+ class ArrayRef;
19
+ template <typename T>
20
+ class OptionalArrayRef;
21
+
22
+ } // namespace c10
23
+
24
+ namespace at {
25
+
26
+ class Tensor;
27
+ class OptionalTensorRef;
28
+ struct Dimname;
29
+ struct Generator;
30
+ using TensorList = c10::ArrayRef<Tensor>;
31
+ using ITensorListRef = c10::IListRef<Tensor>;
32
+ using IOptTensorListRef = c10::IListRef<OptionalTensorRef>;
33
+ using DimnameList = c10::ArrayRef<Dimname>;
34
+ using IntArrayRef = c10::ArrayRef<int64_t>;
35
+ using OptionalIntArrayRef = c10::OptionalArrayRef<int64_t>;
36
+ using OptionalSymIntArrayRef = c10::OptionalArrayRef<c10::SymInt>;
37
+
38
+ using c10::Stream;
39
+ using c10::Storage;
40
+ using c10::QScheme;
41
+ using c10::Scalar;
42
+ using c10::SymInt;
43
+ using c10::SymIntList;
44
+ using c10::TensorOptions;
45
+
46
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_pch.h ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // This global header must not depend on native_functions.yaml or
2
+ // incremental builds will be next to useless
3
+ #pragma push_macro("TORCH_ASSERT_NO_OPERATORS")
4
+ #define TORCH_ASSERT_NO_OPERATORS
5
+
6
+ #include <cinttypes>
7
+
8
+ // This list of headers was generated using a script that finds
9
+ // high-impact headers and then manually tweaked to remove OS specific
10
+ // or duplicate headers (e.g. <cassert> and <assert.h>) and to remove
11
+ // "impl" headers (e.g. BFloat16-inl.h or complex_math.h in c10).
12
+
13
+ // To generate the initial list:
14
+ // 1. Build pytorch from scratch with all build caching disabled
15
+ // 2. Generate a build trace with ninjatracing (https://github.com/nico/ninjatracing)
16
+ // $ ninjatracing /path/to/pytorch/build/.ninja_log > trace_all.json
17
+ // 3. Run pch_gen.py from https://github.com/peterbell10/build_analysis/
18
+ // $ python pch_gen.py --threshold .80 --target torch_cpu --build_dir /path/to/pytorch/build --trace trace_all.json
19
+ // Where the threshold can be tweaked until c10 and some of ATen
20
+ // core are included but TORCH_ASSERT_NO_OPERATORS still passes.
21
+
22
+ #include <cerrno>
23
+ #include <cmath>
24
+ #include <cstddef>
25
+ #include <cstdint>
26
+ #include <cstdlib>
27
+ #include <cstring>
28
+
29
+ #include <algorithm>
30
+ #include <array>
31
+ #include <atomic>
32
+ #include <chrono>
33
+ #include <complex>
34
+ #include <deque>
35
+ #include <exception>
36
+ #include <functional>
37
+ #include <initializer_list>
38
+ #include <iomanip>
39
+ #include <iosfwd>
40
+ #include <iterator>
41
+ #include <limits>
42
+ #include <list>
43
+ #include <map>
44
+ #include <memory>
45
+ #include <mutex>
46
+ #include <new>
47
+ #include <numeric>
48
+ #include <ostream>
49
+ #include <sstream>
50
+ #include <stdexcept>
51
+ #include <string>
52
+ #include <string_view>
53
+ #include <tuple>
54
+ #include <type_traits>
55
+ #include <typeindex>
56
+ #include <typeinfo>
57
+ #include <unordered_map>
58
+ #include <unordered_set>
59
+ #include <utility>
60
+ #include <vector>
61
+
62
+ #include <c10/core/Allocator.h>
63
+ #include <c10/core/AutogradState.h>
64
+ #include <c10/core/Backend.h>
65
+ #include <c10/core/DefaultDtype.h>
66
+ #include <c10/core/Device.h>
67
+ #include <c10/core/DeviceType.h>
68
+ #include <c10/core/DispatchKey.h>
69
+ #include <c10/core/DispatchKeySet.h>
70
+ #include <c10/core/GeneratorImpl.h>
71
+ #include <c10/core/InferenceMode.h>
72
+ #include <c10/core/Layout.h>
73
+ #include <c10/core/MemoryFormat.h>
74
+ #include <c10/core/OptionalRef.h>
75
+ #include <c10/core/QScheme.h>
76
+ #include <c10/core/Scalar.h>
77
+ #include <c10/core/ScalarType.h>
78
+ #include <c10/core/ScalarTypeToTypeMeta.h>
79
+ #include <c10/core/Storage.h>
80
+ #include <c10/core/StorageImpl.h>
81
+ #include <c10/core/SymBool.h>
82
+ #include <c10/core/SymFloat.h>
83
+ #include <c10/core/SymInt.h>
84
+ #include <c10/core/SymIntArrayRef.h>
85
+ #include <c10/core/SymNodeImpl.h>
86
+ #include <c10/core/TensorImpl.h>
87
+ #include <c10/core/TensorOptions.h>
88
+ #include <c10/core/UndefinedTensorImpl.h>
89
+ #include <c10/core/WrapDimMinimal.h>
90
+ #include <c10/core/impl/LocalDispatchKeySet.h>
91
+ #include <c10/core/impl/PyInterpreter.h>
92
+ #include <c10/core/impl/SizesAndStrides.h>
93
+
94
+ #include <c10/macros/Export.h>
95
+ #include <c10/macros/Macros.h>
96
+
97
+ #include <c10/util/AlignOf.h>
98
+ #include <c10/util/ArrayRef.h>
99
+ #include <c10/util/BFloat16.h>
100
+ #include <c10/util/C++17.h>
101
+ #include <c10/util/ConstexprCrc.h>
102
+ #include <c10/util/Deprecated.h>
103
+ #include <c10/util/DimVector.h>
104
+ #include <c10/util/Exception.h>
105
+ #include <c10/util/ExclusivelyOwned.h>
106
+ #include <c10/util/Flags.h>
107
+ #include <c10/util/Float8_e4m3fn.h>
108
+ #include <c10/util/Float8_e5m2.h>
109
+ #include <c10/util/Float8_e4m3fnuz.h>
110
+ #include <c10/util/Float8_e5m2fnuz.h>
111
+ #include <c10/util/FunctionRef.h>
112
+ #include <c10/util/Half.h>
113
+ #include <c10/util/IdWrapper.h>
114
+ #include <c10/util/Logging.h>
115
+ #include <c10/util/MaybeOwned.h>
116
+ #include <c10/util/Metaprogramming.h>
117
+ #include <c10/util/Optional.h>
118
+ #include <c10/util/Registry.h>
119
+ #include <c10/util/SmallVector.h>
120
+ #include <c10/util/StringUtil.h>
121
+ #include <c10/util/ThreadLocalDebugInfo.h>
122
+ #include <c10/util/Type.h>
123
+ #include <c10/util/TypeCast.h>
124
+ #include <c10/util/TypeIndex.h>
125
+ #include <c10/util/TypeList.h>
126
+ #include <c10/util/TypeSafeSignMath.h>
127
+ #include <c10/util/TypeTraits.h>
128
+ #include <c10/util/UniqueVoidPtr.h>
129
+ #include <c10/util/accumulate.h>
130
+ #include <c10/util/bit_cast.h>
131
+ #include <c10/util/bits.h>
132
+ #include <c10/util/complex.h>
133
+ #include <c10/util/floating_point_utils.h>
134
+ #include <c10/util/intrusive_ptr.h>
135
+ #include <c10/util/irange.h>
136
+ #include <c10/util/llvmMathExtras.h>
137
+ #include <c10/util/python_stub.h>
138
+ #include <c10/util/qint32.h>
139
+ #include <c10/util/qint8.h>
140
+ #include <c10/util/quint2x4.h>
141
+ #include <c10/util/quint4x2.h>
142
+ #include <c10/util/quint8.h>
143
+ #include <c10/util/safe_numerics.h>
144
+ #include <c10/util/string_utils.h>
145
+ #include <c10/util/string_view.h>
146
+ #include <c10/util/typeid.h>
147
+
148
+ #include <ATen/StorageUtils.h>
149
+ #include <ATen/core/ATen_fwd.h>
150
+ #include <ATen/core/DeprecatedTypeProperties.h>
151
+ #include <ATen/core/DeprecatedTypePropertiesRegistry.h>
152
+ #include <ATen/core/DimVector.h>
153
+ #include <ATen/core/Dimname.h>
154
+ #include <ATen/core/Generator.h>
155
+ #include <ATen/core/NamedTensor.h>
156
+ #include <ATen/core/QuantizerBase.h>
157
+ #include <ATen/core/TensorAccessor.h>
158
+ #include <ATen/core/TensorBase.h>
159
+ #include <ATen/core/symbol.h>
160
+
161
+ #pragma pop_macro("TORCH_ASSERT_NO_OPERATORS")
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Array.h ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // A fixed-size array type usable from both host and
4
+ // device code.
5
+
6
+ #include <c10/macros/Macros.h>
7
+ #include <c10/util/irange.h>
8
+
9
+ namespace at::detail {
10
+
11
+ template <typename T, int size_>
12
+ struct Array {
13
+ // NOLINTNEXTLINE(*c-array*)
14
+ T data[size_];
15
+
16
+ C10_HOST_DEVICE T operator[](int i) const {
17
+ return data[i];
18
+ }
19
+ C10_HOST_DEVICE T& operator[](int i) {
20
+ return data[i];
21
+ }
22
+ #if defined(USE_ROCM)
23
+ C10_HOST_DEVICE Array() = default;
24
+ C10_HOST_DEVICE Array(const Array&) = default;
25
+ C10_HOST_DEVICE Array& operator=(const Array&) = default;
26
+ C10_HOST_DEVICE Array(Array&&) = default;
27
+ C10_HOST_DEVICE Array& operator=(Array&&) = default;
28
+ C10_HOST_DEVICE ~Array() = default;
29
+ #else
30
+ Array() = default;
31
+ Array(const Array&) = default;
32
+ Array& operator=(const Array&) = default;
33
+ Array(Array&&) noexcept = default;
34
+ Array& operator=(Array&&) noexcept = default;
35
+ ~Array() = default;
36
+ #endif
37
+ static constexpr int size() {
38
+ return size_;
39
+ }
40
+ // Fill the array with x.
41
+ C10_HOST_DEVICE Array(T x) {
42
+ for (int i = 0; i < size_; i++) {
43
+ data[i] = x;
44
+ }
45
+ }
46
+ };
47
+
48
+ } // namespace at::detail
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Backtrace.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #include <c10/util/Backtrace.h>
2
+ #include <c10/util/Type.h>
.venv/lib/python3.12/site-packages/torch/include/ATen/core/CachingHostAllocator.h ADDED
@@ -0,0 +1,737 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/Allocator.h>
4
+ #include <c10/core/Stream.h>
5
+ #include <c10/core/thread_pool.h>
6
+ #include <c10/util/flat_hash_map.h>
7
+ #include <c10/util/llvmMathExtras.h>
8
+ #include <optional>
9
+
10
+ #include <deque>
11
+ #include <mutex>
12
+
13
+ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
14
+ namespace at {
15
+
16
+ using c10::CachingAllocator::Stat;
17
+ using c10::CachingAllocator::DurationStat;
18
+
19
+ /**
20
+ * HostBlock is typically a fundamental memory block used in pinned memory. It
21
+ * is likely related to Event and Stream of device runtime. It is probably a
22
+ * base struct or interface that can be inherited and extended by each backend.
23
+ */
24
+ template <typename S>
25
+ struct HostBlock {
26
+ // constructor for search key
27
+ HostBlock(size_t size) : size_(size) {}
28
+
29
+ HostBlock(size_t size, void* ptr) : size_(size), ptr_(ptr) {}
30
+
31
+ std::mutex mutex_;
32
+ size_t size_{0}; // block size in bytes
33
+ void* ptr_{nullptr}; // memory address
34
+ bool allocated_{false}; // in-use flag
35
+ size_t event_count_{0}; // number of related events
36
+ ska::flat_hash_set<S> streams_; // streams on which the block was used
37
+ };
38
+
39
+ template <typename B>
40
+ struct alignas(64) FreeBlockList {
41
+ std::mutex mutex_;
42
+ std::deque<B*> list_;
43
+ };
44
+
45
+ namespace {
46
+ // Max cached block sizes: (1 << MAX_SIZE_INDEX) bytes
47
+ // NOLINTNEXTLINE(misc-definitions-in-headers)
48
+ constexpr size_t MAX_SIZE_INDEX = 64;
49
+ }
50
+
51
+ // Struct containing memory allocator summary statistics for host.
52
+ struct TORCH_API HostStats {
53
+ // COUNT: allocations requested by client code. Note that active
54
+ // count can be extracted by looking at current allocations
55
+ Stat allocation;
56
+ // COUNT: number of allocated segments from host memory allocation.
57
+ Stat segment;
58
+
59
+ // SUM: bytes allocated by this memory allocator. Note that active bytes
60
+ // can be extracted by looking at current bytes allocated
61
+ Stat allocated_bytes;
62
+ // SUM: bytes reserved by this memory allocator (both free and used)
63
+ Stat reserved_bytes;
64
+
65
+ // SUM: time spent in cudaHostAlloc/cudaHostRegister in microseconds
66
+ DurationStat host_alloc_time;
67
+
68
+ // SUM: time spent in cudaHostFree/cudaHostUnregister in microseconds
69
+ DurationStat host_free_time;
70
+
71
+ // COUNT: number of times cudaHostAlloc/cudaHostRegister was called because
72
+ // the request could not be satisfied from existing free blocks.
73
+ int64_t num_host_alloc = 0; // This is derived from segment or timing
74
+
75
+ // COUNT: number of times cudaHostFree/cudaHostUnregister was called.
76
+ int64_t num_host_free = 0; // This is derived from segment or timing
77
+ };
78
+
79
+ // Struct containing memory allocator summary statistics for host, as they
80
+ // are staged for reporting. This is a temporary struct that is used to
81
+ // avoid locking the allocator while collecting stats.
82
+ struct alignas(64) HostStatsStaged {
83
+ std::mutex timing_mutex_;
84
+ // COUNT: allocations requested by client code resulting in a new segment/block allocation
85
+ // LOCK: access to this stat is protected by the allocator's blocks_mutex_
86
+ Stat allocation;
87
+ // SUM: bytes within active memory blocks, including blocks that are
88
+ // currently in the free list.
89
+ // LOCK: access to this stat is protected by the allocator's blocks_mutex_
90
+ Stat allocated_bytes;
91
+ // COUNT: number of allocations per bucket
92
+ // LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_
93
+ std::vector<Stat> allocation_bucket_stats = std::vector<Stat>(MAX_SIZE_INDEX);
94
+ // SUM: bytes of allocation per bucket
95
+ // LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_
96
+ std::vector<Stat> allocated_bytes_bucket_stats = std::vector<Stat>(MAX_SIZE_INDEX);
97
+ // SUM: time spent in cudaHostAlloc/cudaHostRegister
98
+ // LOCK: access to this stat is protected by the timing_mutex_
99
+ DurationStat host_alloc_time;
100
+ // SUM: time spent in cudaHostFree/cudaHostUnregister
101
+ // LOCK: access to this stat is protected by the timing_mutex_
102
+ DurationStat host_free_time;
103
+ };
104
+
105
+ /**
106
+ * Note [HostAllocator design]
107
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
108
+ * We have three key data structures - the free list which stores blocks that
109
+ * are not currently used, the block list which stores all blocks that have been
110
+ * allocated, and the event queue which stores runtime events and their
111
+ * corresponding blocks.
112
+ *
113
+ * Each of these is protected by a separate mutex. The key design principles
114
+ * are to 1) only hold each mutex for the minimal amount of time possible, 2)
115
+ * never do any potentially expensive operations (such as CUDA runtime API calls)
116
+ * while holding the lock.
117
+ *
118
+ * There are four public methods: allocate, free, record_event and empty_cache.
119
+ * 1) In the allocate path, we first check to see if we can service our
120
+ * request from this free list, and otherwise we create a new block with
121
+ * allocate_host_memory.
122
+ * 2) In the free path, we insert events (if required) into the event queue,
123
+ * and if possible insert our block back into the free list. In allocate, we
124
+ * first eagerly query events until we find one that is not ready, and insert
125
+ * the corresponding block onto the free list if all the events recorded for a
126
+ * block are ready.
127
+ * 3) In the record_event path, we simply insert the given stream into the set
128
+ * of streams tracked by the specified block. This set of streams is then
129
+ * consumed in the free path.
130
+ * 4) In the empty_cache path, we flush any available blocks into the free
131
+ * list. Remove all elements of the free list, then remove them from the block list and
132
+ * release the associated pinned memory allocation via free_block.
133
+ *
134
+ * We generalize the caching host allocator into two parts: interface and
135
+ * implementation. For any new backend looking to integrate with host allocator
136
+ * and reuse the caching mechanism, both parts need to be specialized.
137
+ *
138
+ * For the implementation, we provide a CachingHostAllocatorImpl struct
139
+ * to abstract the caching mechanism. Any backend needs to provide a customized
140
+ * implementation by specializing its own public functions and the related
141
+ * runtime functions. Its template parameter S represents runtime Stream, E
142
+ * denotes runtime Event, B indicates the fundamental memory block.
143
+ *
144
+ * For the interface, we provide a CachingHostAllocatorInterface struct as an
145
+ * interface. Any backend needs to derive its own host allocator from this
146
+ * interface. Its template parameter T refers to an implementation that
147
+ * inherits from CachingHostAllocatorImpl.
148
+ *
149
+ * So this design can share the caching mechanism across each backend, and
150
+ * provide flexibility to each backend. A backend can choose to follow this
151
+ * implementation or reuse them by extending and overriding them as necessary.
152
+ * Taking CUDA as an example, it specializes runtime related functions to reuse
153
+ * the caching mechanism. Additionally, it extends the allocator's functionality
154
+ * by adding the allocWithCudaHostRegister function to support page-locking the
155
+ * memory range used by CUDA. Of course, you can also refer to
156
+ * XPUCachingHostAllocator, which is a host caching allocator supported on XPU
157
+ * backend, to implement a basic host caching allocator.
158
+ *
159
+ * Some of the invariants here are less strict than they could be - for example,
160
+ * we do not enforce that free(Block* block) => block->event_count == 0. This is
161
+ * for compatibility reasons, and we can explore enforcing these in subsequent
162
+ * versions.
163
+ *
164
+ * Note that this caching host allocator does not split larger allocations into
165
+ * smaller blocks, unlike the caching device allocator.
166
+ *
167
+ * In order to gather statistics about caching host allocator while minimally
168
+ * impacting performance, we use a HostStatsStaged struct to stage the stats
169
+ * before reporting them. This is done to avoid adding new locks to the allocator.
170
+ * Collecting stats is carefully done under existing locks, and then the staged
171
+ * stats are converted to the final stats when getStats is called. At that time
172
+ * we hold the same locks as empty_cache, to ensure the fidelity of the stats.
173
+ */
174
+
175
+ template <
176
+ typename S,
177
+ typename E,
178
+ typename B = HostBlock<S>>
179
+ struct CachingHostAllocatorImpl {
180
+ virtual ~CachingHostAllocatorImpl() {
181
+ active_ = false;
182
+ if (pinned_use_background_threads()) {
183
+ getBackgroundThreadPool()->waitWorkComplete();
184
+ }
185
+ }
186
+
187
+ public:
188
+ // return data_ptr and block pair.
189
+ virtual std::pair<void*, void*> allocate(size_t size) {
190
+ if (size == 0) {
191
+ return {nullptr, nullptr};
192
+ }
193
+
194
+ // If we are using background threads, we can process events in the
195
+ // background.
196
+ if (!pinned_use_background_threads()) {
197
+ process_events();
198
+ }
199
+
200
+ // Round up the allocation to the nearest power of two to improve reuse.
201
+ // These power of two sizes are also used to index into the free list.
202
+ size_t roundSize = c10::llvm::PowerOf2Ceil(size);
203
+
204
+ // First, try to allocate from the free list
205
+ auto* block = get_free_block(roundSize);
206
+ if (block) {
207
+ return {block->ptr_, reinterpret_cast<void*>(block)};
208
+ }
209
+
210
+ // Check in the recently freed blocks with pending events to see if we
211
+ // can reuse them. Call get_free_block again after processing events
212
+ if (pinned_use_background_threads()) {
213
+ process_events_for_specific_size(roundSize);
214
+ block = get_free_block(roundSize);
215
+ if (block) {
216
+ return {block->ptr_, reinterpret_cast<void*>(block)};
217
+ }
218
+
219
+ // Launch the background thread and process events in a loop.
220
+ static bool background_thread_flag [[maybe_unused]] = [this] {
221
+ getBackgroundThreadPool()->run([&]() {
222
+ while (active_) {
223
+ process_events();
224
+ std::this_thread::sleep_for(std::chrono::microseconds(100));
225
+ }
226
+ });
227
+ return true;
228
+ }();
229
+ }
230
+
231
+ // Slow path: if we can't allocate from the cached free list, we need
232
+ // to create a new block.
233
+ void* ptr = nullptr;
234
+ allocate_host_memory(roundSize, &ptr);
235
+
236
+ // Then, create a new block.
237
+ block = new B(roundSize, ptr);
238
+ block->allocated_ = true;
239
+
240
+ add_allocated_block(block);
241
+ return {block->ptr_, reinterpret_cast<void*>(block)};
242
+ }
243
+
244
+ virtual void free(void* ctx) {
245
+ if (!ctx) {
246
+ return;
247
+ }
248
+
249
+ // Note: we can assume that free is correctly paired with alloc, and thus we
250
+ // do not need to look up the ctx in blocks_.
251
+ auto* block = reinterpret_cast<B*>(ctx);
252
+
253
+ std::optional<std::vector<E>> events;
254
+ {
255
+ std::lock_guard<std::mutex> g(block->mutex_);
256
+ block->allocated_ = false;
257
+ if (block->streams_.empty()) {
258
+ TORCH_INTERNAL_ASSERT(block->event_count_ == 0);
259
+ } else {
260
+ events = std::vector<E>();
261
+ events->reserve(block->streams_.size());
262
+ for (auto stream : block->streams_) {
263
+ record_stream(events, stream);
264
+ }
265
+ block->event_count_ += events->size();
266
+ block->streams_.clear();
267
+ }
268
+ }
269
+
270
+ if (!events) {
271
+ auto index = size_index(block->size_);
272
+ std::lock_guard<std::mutex> g(free_list_[index].mutex_);
273
+ free_list_[index].list_.push_back(block);
274
+ stats_.allocation_bucket_stats[index].decrease(1);
275
+ stats_.allocated_bytes_bucket_stats[index].decrease(block->size_);
276
+ } else {
277
+ // Queue the events that were recorded on the streams that used this block.
278
+ std::lock_guard<std::mutex> g(events_mutex_);
279
+ for (auto&& event : *events) {
280
+ events_.emplace_front(std::move(event), block);
281
+ }
282
+ }
283
+ }
284
+
285
+ virtual bool record_event(void* ptr, void* ctx, c10::Stream s) {
286
+ S stream = S(s);
287
+ auto* block = reinterpret_cast<B*>(ctx);
288
+
289
+ // Note: we need to check if the passed-in `ctx` is valid. This is because
290
+ // `record_event` (via `CachingHostAllocator_recordEvent`) can be invoked on
291
+ // an arbitrary tensor, and is not guaranteed to correspond to a pinned
292
+ // memory allocation. Therefore, we need to check that `ctx` is valid before
293
+ // proceeding.
294
+ {
295
+ std::lock_guard<std::mutex> g(blocks_mutex_);
296
+ if (blocks_.find(block) != blocks_.end()) {
297
+ // Now we know this object is safe to access.
298
+ std::lock_guard<std::mutex> gb(block->mutex_);
299
+ TORCH_INTERNAL_ASSERT(block->allocated_);
300
+ block->streams_.insert(stream);
301
+ return true;
302
+ }
303
+ auto it = ptr_to_block_.find(ptr);
304
+ if (it != ptr_to_block_.end()) {
305
+ block = it->second;
306
+ std::lock_guard<std::mutex> g(block->mutex_);
307
+ TORCH_INTERNAL_ASSERT(block->allocated_);
308
+ block->streams_.insert(stream);
309
+ return true;
310
+ }
311
+ }
312
+
313
+ return false;
314
+ }
315
+
316
+ virtual void empty_cache() {
317
+ // Flush any available blocks into the free_list.
318
+ process_events();
319
+
320
+ // Remove all elements from the free list, remove them from the blocks
321
+ // list, and free the associated pinned memory allocation. This requires
322
+ // concurrently holding both the free list mutexes and the blocks mutex, and
323
+ // is the only function that concurrently holds multiple mutexes.
324
+ for (size_t i = 0; i < free_list_.size(); ++i) {
325
+ std::lock(free_list_[i].mutex_, blocks_mutex_);
326
+ std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
327
+ std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
328
+
329
+ std::vector<B*> blocks_to_remove(free_list_[i].list_.begin(), free_list_[i].list_.end());
330
+ free_list_[i].list_.clear();
331
+
332
+ for (auto* block : blocks_to_remove) {
333
+ blocks_.erase(block);
334
+ ptr_to_block_.erase(block->ptr_);
335
+ stats_.allocation.decrease(1);
336
+ stats_.allocated_bytes.decrease(block->size_);
337
+ free_block(block);
338
+ delete block;
339
+ }
340
+ }
341
+ }
342
+
343
+ inline size_t size_index(size_t size) {
344
+ return c10::llvm::Log2_64_Ceil(size);
345
+ }
346
+
347
+ virtual bool pinned_use_background_threads() {
348
+ return false;
349
+ }
350
+
351
+ virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const {
352
+ TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data");
353
+ }
354
+
355
+ HostStats getStats() {
356
+ HostStats stats;
357
+
358
+ // To keep getStats lightweight we do *not* flush any available blocks
359
+ // into the free_list. This may skew the stats a bit.
360
+
361
+ auto add_bucket_stats = [](Stat& accumulator, const Stat& other) {
362
+ accumulator.allocated += other.allocated;
363
+ accumulator.current += other.current;
364
+ accumulator.freed += other.freed;
365
+ // Since peaks are measured per bucket independently, we add them up
366
+ // to estimate the total peak. This is not strictly correct, but it is
367
+ // the best approximation we can get after the fact.
368
+ accumulator.peak += other.peak;
369
+ };
370
+
371
+ // Accurate reading of memory stats requires concurrently holding both the
372
+ // free list mutexes and the blocks mutex. Previously, this was only done in
373
+ // empty_cache function.
374
+ for (size_t i = 0; i < free_list_.size(); ++i) {
375
+ std::lock(free_list_[i].mutex_, blocks_mutex_);
376
+ std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
377
+ std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
378
+
379
+ // We collect the slow-path stats only once, since they are not collected
380
+ // per bucket (we pick index 0 arbitrarily). These are also all the host
381
+ // allocations, not taking into account caching and free lists.
382
+ if (i == 0) {
383
+ stats.segment = stats_.allocation;
384
+ stats.reserved_bytes = stats_.allocated_bytes;
385
+ stats.num_host_alloc = stats.segment.allocated;
386
+ stats.num_host_free = stats.segment.freed;
387
+ }
388
+
389
+ // Bucket stats need to be merged with the slow-path stats. We do this in
390
+ // a best effort manner, since we can't really replay the cached events per bucket.
391
+ add_bucket_stats(stats.allocation, stats_.allocation_bucket_stats[i]);
392
+ add_bucket_stats(stats.allocated_bytes, stats_.allocated_bytes_bucket_stats[i]);
393
+ }
394
+
395
+ // Get the timing stats
396
+ {
397
+ std::lock_guard<std::mutex> g(stats_.timing_mutex_);
398
+
399
+ stats.host_alloc_time = stats_.host_alloc_time;
400
+ stats.host_free_time = stats_.host_free_time;
401
+ }
402
+
403
+ return stats;
404
+ }
405
+
406
+ void resetAccumulatedStats() {
407
+ // Reseting accumulated memory stats requires concurrently holding both the
408
+ // free list mutexes and the blocks mutex. Previously, this was only done in
409
+ // empty_cache function.
410
+ for (size_t i = 0; i < free_list_.size(); ++i) {
411
+ std::lock(free_list_[i].mutex_, blocks_mutex_);
412
+ std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
413
+ std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
414
+
415
+ if (i == 0) {
416
+ stats_.allocation.reset_accumulated();
417
+ stats_.allocated_bytes.reset_accumulated();
418
+ }
419
+ stats_.allocation_bucket_stats[i].reset_accumulated();
420
+ stats_.allocated_bytes_bucket_stats[i].reset_accumulated();
421
+ }
422
+
423
+ // Also reset timing stats
424
+ {
425
+ std::lock_guard<std::mutex> g(stats_.timing_mutex_);
426
+ stats_.host_alloc_time.reset_accumulated();
427
+ stats_.host_free_time.reset_accumulated();
428
+ }
429
+ }
430
+
431
+ void resetPeakStats() {
432
+ // Reseting peak memory stats requires concurrently holding both the
433
+ // free list mutexes and the blocks mutex. Previously, this was only done in
434
+ // empty_cache function.
435
+ for (size_t i = 0; i < free_list_.size(); ++i) {
436
+ std::lock(free_list_[i].mutex_, blocks_mutex_);
437
+ std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
438
+ std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
439
+
440
+ if (i == 0) {
441
+ stats_.allocation.reset_peak();
442
+ stats_.allocated_bytes.reset_peak();
443
+ }
444
+ stats_.allocation_bucket_stats[i].reset_peak();
445
+ stats_.allocated_bytes_bucket_stats[i].reset_peak();
446
+ }
447
+
448
+ // Also reset timing stats
449
+ {
450
+ std::lock_guard<std::mutex> g(stats_.timing_mutex_);
451
+ stats_.host_alloc_time.reset_peak();
452
+ stats_.host_free_time.reset_peak();
453
+ }
454
+ }
455
+
456
+ private:
457
+ virtual void add_allocated_block(B* block) {
458
+ std::lock_guard<std::mutex> g(blocks_mutex_);
459
+ blocks_.insert(block);
460
+ stats_.allocation.increase(1);
461
+ stats_.allocated_bytes.increase(block->size_);
462
+ ptr_to_block_.insert({block->ptr_, block});
463
+
464
+ // Unfortunately, we have to, on the slow path, quickly
465
+ // lock the bucket to record the allocation. This should
466
+ // be a rare event once the cache is warmed up.
467
+ auto size = block->size_;
468
+ auto index = size_index(size);
469
+ {
470
+ std::lock_guard<std::mutex> g(free_list_[index].mutex_);
471
+ stats_.allocation_bucket_stats[index].increase(1);
472
+ stats_.allocated_bytes_bucket_stats[index].increase(size);
473
+ }
474
+ }
475
+
476
+ virtual B* get_free_block(size_t size) {
477
+ auto index = size_index(size);
478
+ std::lock_guard<std::mutex> g(free_list_[index].mutex_);
479
+ if (!free_list_[index].list_.empty()) {
480
+ B* block = free_list_[index].list_.back();
481
+ free_list_[index].list_.pop_back();
482
+ block->allocated_ = true;
483
+ stats_.allocation_bucket_stats[index].increase(1);
484
+ stats_.allocated_bytes_bucket_stats[index].increase(size);
485
+ return block;
486
+ }
487
+ return nullptr;
488
+ }
489
+
490
+ virtual void process_events() {
491
+ // process all events until the last unready event, not for specific size.
492
+ process_events_for_specific_size(-1);
493
+ }
494
+
495
+ // If size is -1, process all events from backwards until the last unready
496
+ // event. Otherwise, process events for a specific size and on first ready block
497
+ // is found, add it to the free list and return.
498
+ virtual void process_events_for_specific_size(int64_t size) {
499
+ size_t event_count = 0;
500
+ size_t max_events = 0;
501
+ {
502
+ std::lock_guard<std::mutex> g(events_mutex_);
503
+ max_events = events_.size();
504
+ }
505
+
506
+ while (true) {
507
+ // Avoid calling cudaEventDestroy while holding a mutex, so move
508
+ // intermediate events out of the lock into this object.
509
+ // process the last event
510
+ std::optional<std::pair<E, B*>> processed;
511
+ {
512
+ std::lock_guard<std::mutex> g(events_mutex_);
513
+ if (!events_.empty()) {
514
+ processed = std::move(events_.back());
515
+ events_.pop_back();
516
+ }
517
+ }
518
+
519
+ if (!processed) {
520
+ return;
521
+ }
522
+
523
+ if (size != -1) {
524
+ if (event_count++ > max_events) {
525
+ {
526
+ std::lock_guard<std::mutex> g(events_mutex_);
527
+ events_.push_front(std::move(*processed));
528
+ }
529
+ return;
530
+ }
531
+ if (size != (int64_t)processed->second->size_) {
532
+ // if we are processing a specific size, and the size of the block
533
+ // doesn't match, we can't use it.
534
+ {
535
+ std::lock_guard<std::mutex> g(events_mutex_);
536
+ events_.push_front(std::move(*processed));
537
+ }
538
+ continue;
539
+ }
540
+ }
541
+
542
+ // otherwise, query the event
543
+ {
544
+ // now, see if we can handle this element
545
+ auto& event = processed->first;
546
+ if (!query_event(event)) {
547
+ // push the event onto the back if it's not ready.
548
+ {
549
+ std::lock_guard<std::mutex> g(events_mutex_);
550
+ if (size == -1) {
551
+ events_.push_back(std::move(*processed));
552
+ return;
553
+ } else {
554
+ events_.push_front(std::move(*processed));
555
+ continue;
556
+ }
557
+ }
558
+ }
559
+ }
560
+
561
+ // Process the events.
562
+ TORCH_INTERNAL_ASSERT(processed);
563
+ auto* block = processed->second;
564
+ bool available = false;
565
+ {
566
+ std::lock_guard<std::mutex> g(block->mutex_);
567
+ TORCH_INTERNAL_ASSERT(!block->allocated_)
568
+ block->event_count_--;
569
+ if (block->event_count_ == 0) {
570
+ available = true;
571
+ }
572
+ }
573
+
574
+ if (available) {
575
+ auto index = size_index(block->size_);
576
+ std::lock_guard<std::mutex> g(free_list_[index].mutex_);
577
+ free_list_[index].list_.push_back(block);
578
+ stats_.allocation_bucket_stats[index].decrease(1);
579
+ stats_.allocated_bytes_bucket_stats[index].decrease(size);
580
+ if (size != -1) {
581
+ return;
582
+ }
583
+ }
584
+ }
585
+ }
586
+
587
+ TaskThreadPool* getBackgroundThreadPool() {
588
+ static TaskThreadPool* pool = new TaskThreadPool(1);
589
+ return pool;
590
+ }
591
+
592
+ /* These following functions are runtime-related. */
593
+
594
+ // Allocate page-locked memory on the host.
595
+ virtual void allocate_host_memory(size_t size, void** ptr) {
596
+ TORCH_CHECK_NOT_IMPLEMENTED(
597
+ false, "Not implemented for allocate_host_memory");
598
+ }
599
+
600
+ // Free block and release the pointer contained in block.
601
+ virtual void free_block(B* block) {
602
+ TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block");
603
+ }
604
+
605
+ // Record an event on stream and store event into events.
606
+ virtual void record_stream(std::optional<std::vector<E>>& events, S stream) {
607
+ TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream");
608
+ }
609
+
610
+ // Query event if it is completed.
611
+ virtual bool query_event(E& event) {
612
+ TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
613
+ }
614
+
615
+ alignas(64) std::mutex blocks_mutex_;
616
+ ska::flat_hash_set<B*> blocks_; // block list
617
+ ska::flat_hash_map<void*, B*> ptr_to_block_;
618
+
619
+ // We keep free list as a vector of free lists, one for each power of two
620
+ // size. This allows us to quickly find a free block of the right size.
621
+ // We use deque to store per size free list and guard the list with its own
622
+ // mutex.
623
+ alignas(64) std::vector<FreeBlockList<B>> free_list_ =
624
+ std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
625
+
626
+ alignas(64) std::mutex events_mutex_;
627
+ std::deque<std::pair<E, B*>> events_; // event queue paired with block
628
+
629
+ // Indicates whether the object is active.
630
+ // Set to false in the destructor to signal background threads to stop.
631
+ std::atomic<bool> active_{true};
632
+ protected:
633
+ alignas(64) HostStatsStaged stats_;
634
+ };
635
+
636
+ struct TORCH_API HostAllocator : public at::Allocator {
637
+ // Associates the pinned memory allocation with a stream to track
638
+ // dependencies. This ensures the memory won't be reused until the stream's
639
+ // operations complete
640
+ virtual bool record_event(void* ptr, void* ctx, c10::Stream stream) = 0;
641
+
642
+ // Frees all cached pinned memory and returns it to the system, clearing the
643
+ // allocator's internal cache
644
+ virtual void empty_cache() = 0;
645
+
646
+ // Returns comprehensive statistics about the allocator's memory usage,
647
+ // allocation patterns, and timing metrics
648
+ virtual HostStats get_stats() = 0;
649
+
650
+ // Resets the cumulative allocation statistics
651
+ virtual void reset_accumulated_stats() = 0;
652
+
653
+ // Resets the peak memory usage metrics
654
+ virtual void reset_peak_stats() = 0;
655
+ };
656
+
657
+ template <typename T, c10::DeleterFnPtr deleteFunc>
658
+ struct CachingHostAllocatorInterface : public HostAllocator {
659
+ CachingHostAllocatorInterface() : impl_(std::make_unique<T>()) {}
660
+
661
+ at::DataPtr allocate(size_t size) override {
662
+ auto ptr_and_ctx = impl_->allocate(size);
663
+ return {
664
+ ptr_and_ctx.first,
665
+ ptr_and_ctx.second,
666
+ deleteFunc, // Use the template parameter deleter function
667
+ at::DeviceType::CPU};
668
+ }
669
+
670
+ void free(void* ctx) {
671
+ impl_->free(ctx);
672
+ }
673
+
674
+ bool record_event(void* ptr, void* ctx, c10::Stream stream) override {
675
+ return impl_->record_event(ptr, ctx, stream);
676
+ }
677
+
678
+ void empty_cache() override {
679
+ impl_->empty_cache();
680
+ }
681
+
682
+ void copy_data(void* dest, const void* src, std::size_t count)
683
+ const override {
684
+ impl_->copy_data(dest, src, count);
685
+ }
686
+
687
+ HostStats get_stats() override {
688
+ return impl_->getStats();
689
+ }
690
+
691
+ void reset_accumulated_stats() override {
692
+ impl_->resetAccumulatedStats();
693
+ }
694
+
695
+ void reset_peak_stats() override {
696
+ impl_->resetPeakStats();
697
+ }
698
+
699
+ std::unique_ptr<T> impl_;
700
+ };
701
+
702
+ #define DECLARE_HOST_ALLOCATOR(name, impl, deleter, instance) \
703
+ void deleter(void* ptr); \
704
+ struct name final \
705
+ : public at::CachingHostAllocatorInterface<impl, deleter> {}; \
706
+ static name instance; \
707
+ void deleter(void* ptr) { \
708
+ instance.free(ptr); \
709
+ }
710
+
711
+ /**
712
+ * Set the host allocator for DeviceType `device_type`. This allocator manages
713
+ * pinned memory on the host that can be accessed efficiently by the specified
714
+ * device type. Note that this function is not thread-safe.
715
+ */
716
+ TORCH_API void setHostAllocator(
717
+ at::DeviceType device_type,
718
+ at::HostAllocator* allocator,
719
+ uint8_t priority = 0);
720
+
721
+ TORCH_API at::HostAllocator* getHostAllocator(at::DeviceType device_type);
722
+
723
+ template <DeviceType device_type>
724
+ struct HostAllocatorRegistry {
725
+ explicit HostAllocatorRegistry(HostAllocator* allocator) {
726
+ at::setHostAllocator(device_type, allocator);
727
+ }
728
+ };
729
+
730
+ #define REGISTER_HOST_ALLOCATOR(device_type, allocator) \
731
+ namespace { \
732
+ static at::HostAllocatorRegistry<device_type> \
733
+ g_host_allocator_registry_instance(allocator); \
734
+ }
735
+
736
+ } // namespace at
737
+ C10_DIAGNOSTIC_POP()
.venv/lib/python3.12/site-packages/torch/include/ATen/core/CheckMemoryFormat.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <c10/core/TensorOptions.h>
2
+
3
+ namespace c10::impl {
4
+
5
+ inline std::optional<MemoryFormat>
6
+ check_tensor_options_and_extract_memory_format(
7
+ const TensorOptions& options,
8
+ std::optional<MemoryFormat> memory_format) {
9
+ TORCH_CHECK(
10
+ options.requires_grad_opt() != true,
11
+ "Operators taking TensorOptions cannot take a TensorOptions with "
12
+ "options.requires_grad set as true. This isn't implemented yet.");
13
+ TORCH_CHECK(
14
+ !(options.has_memory_format() && memory_format.has_value()),
15
+ "Cannot set memory_format both in TensorOptions and explicit argument; please delete "
16
+ "the redundant setter.");
17
+ if (memory_format.has_value()) {
18
+ return memory_format;
19
+ } else {
20
+ return options.memory_format_opt();
21
+ }
22
+ }
23
+
24
+ } // namespace c10::impl
.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/Backend.h>
4
+ #include <c10/core/ScalarType.h>
5
+ #include <c10/core/Layout.h>
6
+ #include <c10/core/TensorOptions.h>
7
+ #include <c10/core/Storage.h>
8
+ #include <ATen/core/DeprecatedTypePropertiesRegistry.h>
9
+ #include <ATen/core/Generator.h>
10
+
11
+
12
+ namespace at {
13
+
14
+ class Tensor;
15
+
16
+ // This class specifies a Backend and a ScalarType. Currently, it primarily
17
+ // serves as a replacement return value for Tensor::type(). Previously,
18
+ // Tensor::type() returned Type&, but we are changing Type to not be
19
+ // dtype-specific.
20
+ class TORCH_API DeprecatedTypeProperties {
21
+ public:
22
+ DeprecatedTypeProperties(Backend backend, ScalarType scalar_type)
23
+ : backend_(backend), scalar_type_(scalar_type) {}
24
+
25
+ Backend backend() const {
26
+ return backend_;
27
+ }
28
+
29
+ Layout layout() const {
30
+ return layout_from_backend(backend_);
31
+ }
32
+
33
+ bool is_sparse() const {
34
+ return layout_from_backend(backend()) == kSparse;
35
+ }
36
+
37
+ bool is_sparse_csr() const {
38
+ return layout_from_backend(backend()) == kSparseCsr;
39
+ }
40
+
41
+ c10::DeviceType device_type() const {
42
+ return backendToDeviceType(backend_);
43
+ }
44
+
45
+ bool is_cuda() const {
46
+ return backendToDeviceType(backend_) == kCUDA;
47
+ }
48
+
49
+ ScalarType scalarType() const {
50
+ return scalar_type_;
51
+ }
52
+
53
+ caffe2::TypeMeta typeMeta() const {
54
+ return scalarTypeToTypeMeta(scalar_type_);
55
+ }
56
+
57
+ bool operator==(const DeprecatedTypeProperties& other) const {
58
+ return backend_ == other.backend() && scalar_type_ == other.scalarType();
59
+ }
60
+
61
+ bool operator!=(const DeprecatedTypeProperties& other) const {
62
+ return !(*this == other);
63
+ }
64
+
65
+ std::string toString() const {
66
+ std::string base_str;
67
+ if (backend_ == Backend::Undefined || scalar_type_ == ScalarType::Undefined) {
68
+ base_str = "UndefinedType";
69
+ } else {
70
+ base_str = std::string(at::toString(backend_)) + at::toString(scalar_type_) + "Type";
71
+ }
72
+ return base_str;
73
+ }
74
+
75
+ DeprecatedTypeProperties & toBackend(Backend b) const {
76
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
77
+ b, scalar_type_);
78
+ }
79
+
80
+ DeprecatedTypeProperties & toScalarType(ScalarType s) const {
81
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
82
+ backend_, s);
83
+ }
84
+
85
+ DeprecatedTypeProperties & cpu() const {
86
+ return toBackend(Backend::CPU);
87
+ }
88
+
89
+ DeprecatedTypeProperties & cuda() const {
90
+ return toBackend(Backend::CUDA);
91
+ }
92
+
93
+ DeprecatedTypeProperties & hip() const {
94
+ return toBackend(Backend::HIP);
95
+ }
96
+
97
+ DeprecatedTypeProperties & privateUser1() const {
98
+ return toBackend(Backend::PrivateUse1);
99
+ }
100
+
101
+ /// Constructs the `TensorOptions` from a type and a `device_index`.
102
+ TensorOptions options(int16_t device_index = -1) const {
103
+ return TensorOptions().dtype(typeMeta())
104
+ .device(device_type(), static_cast<c10::DeviceIndex>(device_index))
105
+ .layout(layout());
106
+ }
107
+
108
+ /// Constructs the `TensorOptions` from a type and a Device. Asserts that
109
+ /// the device type matches the device type of the type.
110
+ TensorOptions options(std::optional<Device> device_opt) const {
111
+ if (!device_opt.has_value()) {
112
+ return options(-1);
113
+ } else {
114
+ Device device = device_opt.value();
115
+ AT_ASSERT(device.type() == device_type());
116
+ return options(device.index());
117
+ }
118
+ }
119
+
120
+ operator TensorOptions() const {
121
+ return options();
122
+ }
123
+
124
+ int64_t id() const {
125
+ return static_cast<int64_t>(backend()) *
126
+ static_cast<int64_t>(ScalarType::NumOptions) +
127
+ static_cast<int64_t>(scalarType());
128
+ }
129
+
130
+ Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const;
131
+ Storage unsafeStorageFromTH(void * th_pointer, bool retain) const;
132
+ Tensor copy(const Tensor & src, bool non_blocking=false, std::optional<Device> to_device={}) const;
133
+
134
+ private:
135
+ Backend backend_;
136
+ ScalarType scalar_type_;
137
+ };
138
+
139
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // In order to preserve bc, we make DeprecatedTypeProperties instances unique
4
+ // just like they are for Type.
5
+
6
+ #include <c10/core/Backend.h>
7
+ #include <c10/core/ScalarType.h>
8
+ #include <memory>
9
+
10
+ namespace at {
11
+
12
+ class DeprecatedTypeProperties;
13
+
14
+ struct TORCH_API DeprecatedTypePropertiesDeleter {
15
+ void operator()(DeprecatedTypeProperties * ptr);
16
+ };
17
+
18
+ class TORCH_API DeprecatedTypePropertiesRegistry {
19
+ public:
20
+ DeprecatedTypePropertiesRegistry();
21
+
22
+ DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) const;
23
+
24
+ private:
25
+ // NOLINTNEXTLINE(*c-array*)
26
+ std::unique_ptr<DeprecatedTypeProperties> registry
27
+ [static_cast<int>(Backend::NumOptions)]
28
+ [static_cast<int>(ScalarType::NumOptions)];
29
+ };
30
+
31
+ TORCH_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry();
32
+
33
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict.h ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/macros/Macros.h>
4
+ #include <c10/macros/Export.h>
5
+ #include <c10/util/TypeTraits.h>
6
+ #include <c10/util/TypeList.h>
7
+ #include <c10/util/intrusive_ptr.h>
8
+ #include <c10/util/order_preserving_flat_hash_map.h>
9
+ #include <optional>
10
+ #include <ATen/core/TensorBody.h>
11
+ #include <ATen/core/jit_type_base.h>
12
+
13
+ namespace c10 {
14
+ struct IValue;
15
+ template<class Key, class Value> class Dict;
16
+ struct Type;
17
+
18
+ namespace impl {
19
+
20
+ using valid_dict_key_types = guts::typelist::typelist<
21
+ int64_t,
22
+ std::string,
23
+ double,
24
+ c10::complex<double>,
25
+ bool,
26
+ at::Tensor
27
+ >;
28
+ }
29
+
30
+ namespace detail {
31
+
32
+ struct DictKeyHash {
33
+ size_t operator()(const IValue& ivalue) const;
34
+ };
35
+
36
+ struct DictKeyEqualTo {
37
+ bool operator()(const IValue& lhs, const IValue& rhs) const;
38
+ };
39
+
40
+ struct DictImpl final : public c10::intrusive_ptr_target {
41
+ using dict_map_type = ska_ordered::order_preserving_flat_hash_map<IValue, IValue, DictKeyHash, DictKeyEqualTo>;
42
+ struct DictElementTypes final {
43
+ TypePtr keyType;
44
+ TypePtr valueType;
45
+ };
46
+
47
+ explicit DictImpl(dict_map_type dict_, DictElementTypes elementTypes_)
48
+ : dict(std::move(dict_))
49
+ , elementTypes(std::move(elementTypes_)) {}
50
+ dict_map_type dict;
51
+
52
+ DictElementTypes elementTypes;
53
+
54
+ intrusive_ptr<DictImpl> copy() const;
55
+ friend TORCH_API bool operator==(const DictImpl& lhs, const DictImpl& rhs);
56
+ };
57
+
58
+ }
59
+
60
+ namespace impl {
61
+ template<class Key, class Value, class Iterator> class DictIterator;
62
+
63
+ /**
64
+ * A reference to an entry in the Dict.
65
+ * Use the `key()` and `value()` methods to read the element.
66
+ */
67
+ template<class Key, class Value, class Iterator>
68
+ class DictEntryRef final {
69
+ public:
70
+ explicit DictEntryRef(Iterator iterator)
71
+ : iterator_(std::move(iterator)) {}
72
+
73
+ decltype(auto) key() const {
74
+ return iterator_->first.template to<Key>();
75
+ }
76
+
77
+ decltype(auto) value() const {
78
+ return iterator_->second.template to<Value>();
79
+ }
80
+
81
+ template<class Value_>
82
+ void setValue(Value_&& value) const {
83
+ static_assert(std::is_constructible_v<Value, Value_>, "Wrong type for the value argument of setValue()");
84
+ iterator_->second = Value(std::forward<Value_>(value));
85
+ }
86
+ ~DictEntryRef() = default;
87
+
88
+ private:
89
+ // allow copying and moving, but only our friends (i.e. the Dict class) can do
90
+ // it. Copying/moving this reference wrapper would be too ambiguous to allow it
91
+ // in the public API.
92
+ DictEntryRef(const DictEntryRef&) = default;
93
+ DictEntryRef& operator=(const DictEntryRef&) = default;
94
+ DictEntryRef(DictEntryRef&&) noexcept = default;
95
+ DictEntryRef& operator=(DictEntryRef&& rhs) & noexcept = default;
96
+
97
+ Iterator iterator_;
98
+ friend class DictIterator<Key, Value, Iterator>;
99
+ friend class Dict<Key, Value>;
100
+ };
101
+
102
+ // this wraps map_type::iterator to make sure user code can't rely
103
+ // on it being the type of the underlying map.
104
+ template<class Key, class Value, class Iterator>
105
+ class DictIterator final {
106
+ public:
107
+ // C++17 friendly std::iterator implementation
108
+ using iterator_category = std::forward_iterator_tag;
109
+ using value_type = DictEntryRef<Key, Value, Iterator>;
110
+ using difference_type = std::ptrdiff_t;
111
+ using pointer = value_type*;
112
+ using reference = value_type&;
113
+
114
+ explicit DictIterator() = default;
115
+ ~DictIterator() = default;
116
+
117
+ DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {}
118
+ DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {}
119
+ DictIterator& operator=(const DictIterator& rhs) = default;
120
+ DictIterator& operator=(DictIterator&& rhs) noexcept {
121
+ entryRef_ = std::move(rhs.entryRef_);
122
+ return *this;
123
+ }
124
+
125
+ DictIterator& operator++() {
126
+ ++entryRef_.iterator_;
127
+ return *this;
128
+ }
129
+
130
+ DictIterator operator++(int) {
131
+ DictIterator copy(*this);
132
+ ++*this;
133
+ return copy;
134
+ }
135
+
136
+ const DictEntryRef<Key, Value, Iterator>& operator*() const {
137
+ return entryRef_;
138
+ }
139
+
140
+ const DictEntryRef<Key, Value, Iterator>* operator->() const {
141
+ return &entryRef_;
142
+ }
143
+
144
+ friend difference_type operator-(const DictIterator& lhs, const DictIterator& rhs) {
145
+ return lhs.entryRef_.iterator_ - rhs.entryRef_.iterator_;
146
+ }
147
+
148
+ private:
149
+ explicit DictIterator(Iterator iterator): entryRef_(std::move(iterator)) {}
150
+
151
+ const Iterator& get_iterator_() const {
152
+ return entryRef_.iterator_;
153
+ }
154
+
155
+ friend bool operator==(const DictIterator& lhs, const DictIterator& rhs) {
156
+ return lhs.get_iterator_() == rhs.get_iterator_();
157
+ }
158
+
159
+ friend bool operator!=(const DictIterator& lhs, const DictIterator& rhs) {
160
+ return lhs.get_iterator_() != rhs.get_iterator_();
161
+ }
162
+
163
+ friend bool operator<(const DictIterator& lhs, const DictIterator& rhs) {
164
+ return lhs.get_iterator_() < rhs.get_iterator_();
165
+ }
166
+
167
+ friend bool operator<=(const DictIterator& lhs, const DictIterator& rhs) {
168
+ return lhs.get_iterator_() <= rhs.get_iterator_();
169
+ }
170
+
171
+ friend bool operator>(const DictIterator& lhs, const DictIterator& rhs) {
172
+ return lhs.get_iterator_() > rhs.get_iterator_();
173
+ }
174
+
175
+ friend bool operator>=(const DictIterator& lhs, const DictIterator& rhs) {
176
+ return lhs.get_iterator_() >= rhs.get_iterator_();
177
+ }
178
+
179
+ DictEntryRef<Key, Value, Iterator> entryRef_;
180
+
181
+ friend class DictIterator<Key, Value, typename c10::detail::DictImpl::dict_map_type::iterator>;
182
+ friend class Dict<Key, Value>;
183
+ };
184
+
185
+ template<class Key, class Value> Dict<Key, Value> toTypedDict(Dict<IValue, IValue> dict);
186
+ template<class Key, class Value> Dict<IValue, IValue> toGenericDict(Dict<Key, Value> dict);
187
+ }
188
+
189
+ /**
190
+ * An object of this class stores a map from Key to Value.
191
+ *
192
+ * This is a pointer type. After a copy, both Dicts
193
+ * will share the same storage:
194
+ *
195
+ * > Dict<int, string> a;
196
+ * > Dict<int, string> b = a;
197
+ * > b.insert(3, "three");
198
+ * > ASSERT("three" == a.at(3));
199
+ *
200
+ * We use this class in the PyTorch kernel API because that
201
+ * allows us to do optimizations and switch out the underlying
202
+ * map implementation without breaking backwards compatibility
203
+ * for the kernel API.
204
+ */
205
+ template<class Key, class Value>
206
+ // NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
207
+ class Dict final {
208
+ private:
209
+ static_assert((std::is_same_v<IValue, Key> && std::is_same_v<IValue, Value>) || guts::typelist::contains<impl::valid_dict_key_types, Key>::value, "Invalid Key type for Dict. We only support int64_t, double, bool, and string.");
210
+
211
+ // impl_ stores the underlying map as a ska_ordered::order_preserving_flat_hash_map.
212
+ // We intentionally don't offer conversion from/to
213
+ // order_preserving_flat_hash_map, return references to it or something like that,
214
+ // because such operations would get expensive if we switch out
215
+ // the actual map implementation.
216
+ // This is an intrusive_ptr because Dict is a pointer type.
217
+ // Invariant: This will never be a nullptr, there will always be a valid
218
+ // DictImpl.
219
+ c10::intrusive_ptr<detail::DictImpl> impl_;
220
+
221
+ explicit Dict(c10::intrusive_ptr<detail::DictImpl>&& impl);
222
+ friend struct IValue;
223
+ template<class K, class V> friend Dict<K, V> impl::toTypedDict(Dict<IValue, IValue>);
224
+ template<class K, class V> friend Dict<IValue, IValue> impl::toGenericDict(Dict<K, V>);
225
+
226
+ public:
227
+ using key_type = Key;
228
+ using mapped_type = Value;
229
+ using size_type = typename detail::DictImpl::dict_map_type::size_type;
230
+ using iterator = impl::DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::iterator>;
231
+
232
+ /**
233
+ * Creates an empty dict.
234
+ */
235
+ explicit Dict();
236
+
237
+ /**
238
+ * Create a generic dict with runtime type information.
239
+ * This only works for c10::impl::GenericDict and is not part of the public API
240
+ * but only supposed to be used internally by PyTorch.
241
+ */
242
+ explicit Dict(TypePtr keyType, TypePtr valueType);
243
+
244
+ ~Dict() = default;
245
+
246
+ Dict(const Dict&) = default;
247
+ Dict& operator=(const Dict&) = default;
248
+
249
+ /**
250
+ * Create a new Dict pointing to a deep copy of the same data.
251
+ * The Dict returned is a new dict with separate storage.
252
+ * Changes in it are not reflected in the original dict or vice versa.
253
+ */
254
+ Dict copy() const;
255
+
256
+ /**
257
+ * Returns an iterator to the first element of the container.
258
+ * If the container is empty, the returned iterator will be equal to end().
259
+ */
260
+ iterator begin() const;
261
+
262
+ /**
263
+ * Returns an iterator to the element following the last element of the container.
264
+ * This element acts as a placeholder; attempting to access it results in undefined behavior.
265
+ */
266
+ iterator end() const;
267
+
268
+ /**
269
+ * Checks if the container has no elements.
270
+ */
271
+ bool empty() const;
272
+
273
+ /**
274
+ * Returns the number of elements in the container.
275
+ */
276
+ size_type size() const;
277
+
278
+ /**
279
+ * Erases all elements from the container. After this call, size() returns zero.
280
+ * Invalidates any references, pointers, or iterators referring to contained elements. May also invalidate past-the-end iterators.
281
+ */
282
+ void clear() const;
283
+
284
+ /**
285
+ * Inserts element(s) into the container, if the container doesn't already contain an element with an equivalent key.
286
+ * May invalidate any references, pointers, or iterators referring to contained elements.
287
+ *
288
+ * @return A pair consisting of an iterator to the inserted element (or to the element that prevented the insertion) and a bool denoting whether the insertion took place.
289
+ */
290
+ template<class Key_, class Value_>
291
+ std::pair<iterator, bool> insert(Key_&& key, Value_&& value) const;
292
+
293
+ /**
294
+ * If an element with the given key already exists, it is overwritten with the given value.
295
+ * Otherwise, a new element with the given key and value are inserted.
296
+ * May invalidate any references, pointers, or iterators referring to contained elements.
297
+ *
298
+ * @return The bool component is true if the insertion took place and false if the assignment took place. The iterator component is pointing at the element that was inserted or updated.
299
+ */
300
+ template<class Key_, class Value_>
301
+ std::pair<iterator, bool> insert_or_assign(Key_&& key, Value_&& value) const;
302
+
303
+ /**
304
+ * Removes the element pointed to by iter.
305
+ * May invalidate any references, pointers, or iterators referring to contained elements.
306
+ * The iterator iter must be valid and dereferenceable. Thus the end() iterator (which is valid, but is not dereferenceable) cannot be used as a value for iter.
307
+ */
308
+ void erase(iterator iter) const;
309
+
310
+ /**
311
+ * Removes the element with the given key, if it exists.
312
+ * May invalidate any references, pointers, or iterators referring to contained elements.
313
+ *
314
+ * @return The number of elements removed. This is either '1' if an element with the key existed, or '0' if it didn't.
315
+ */
316
+ [[nodiscard]] size_t erase(const Key& key) const;
317
+
318
+ /**
319
+ * Returns the mapped value of the element with key equivalent to key.
320
+ * If no such element exists, an exception of type std::out_of_range is thrown.
321
+ */
322
+ Value at(const Key& key) const;
323
+
324
+ /**
325
+ * Finds an element with key equivalent to key.
326
+ *
327
+ * @return Iterator to an element with key equivalent to key.
328
+ * If no such element is found, past-the-end (see end()) iterator is returned.
329
+ */
330
+ iterator find(const Key& key) const;
331
+
332
+ /**
333
+ * Checks if there is an element with key equivalent to key in the container.
334
+ *
335
+ * @return true if there is such an element, otherwise false.
336
+ */
337
+ bool contains(const Key& key) const;
338
+
339
+ /**
340
+ * Increase the capacity so that at least count elements can be stored without
341
+ * having to reallocate or rehash.
342
+ */
343
+ void reserve(size_type count) const;
344
+
345
+ /**
346
+ * Value equality comparison. This function implements Python-like semantics for
347
+ * equality: two dicts with the same identity (e.g. same pointer) trivially
348
+ * compare equal, otherwise each element is compared for equality.
349
+ */
350
+ template <class Key_, class Value_>
351
+ friend bool operator==(
352
+ const Dict<Key_, Value_>& lhs,
353
+ const Dict<Key_, Value_>& rhs);
354
+ template <class Key_, class Value_>
355
+ friend bool operator!=(
356
+ const Dict<Key_, Value_>& lhs,
357
+ const Dict<Key_, Value_>& rhs);
358
+
359
+ /**
360
+ * Identity comparison. Returns true if and only if `rhs` represents the same
361
+ * Dict object as `this`.
362
+ */
363
+ bool is(const Dict& rhs) const;
364
+
365
+ // private API for now because the return type will change to TypePtr
366
+ // instead of std::optional<TypePtr> once types are mandatory.
367
+ TypePtr keyType() const;
368
+ TypePtr valueType() const;
369
+
370
+ // [unsafe set type]
371
+ // These functions mutate the tagged type of this dictionary in place.
372
+ // There is no checking that the members of the dictionary are instances
373
+ // of the new types, nor is there a check that other IValues which
374
+ // hold references to this dictionary have the right static type.
375
+ // This functionality is used only in the unpickler, where at
376
+ // creation time the real type of the dictionary is unknown, but
377
+ // then later recovered from the static type information of the
378
+ // unpickled object.
379
+ void unsafeSetKeyType(TypePtr t);
380
+ void unsafeSetValueType(TypePtr t);
381
+ };
382
+
383
+ namespace impl {
384
+ // GenericDict is how IValue stores dicts. It is, however, not part of the
385
+ // public API. Kernels should use Dicts with concrete Key, Value types instead
386
+ // (maybe except for some internal prim ops).
387
+ using GenericDict = Dict<IValue, IValue>;
388
+
389
+ }
390
+ }
391
+
392
+ namespace torch {
393
+ template<class Key, class Value> using Dict = c10::Dict<Key, Value>;
394
+ }
395
+
396
+ #include <ATen/core/Dict_inl.h> // IWYU pragma: keep
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict_inl.h ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/ivalue.h>
4
+ #include <c10/util/hash.h>
5
+
6
+ namespace c10 {
7
+ namespace detail {
8
+ inline bool DictKeyEqualTo::operator()(const IValue& lhs, const IValue& rhs) const {
9
+ if (lhs.isTensor() && rhs.isTensor()) {
10
+ // for tensors, we compare only by identity (following how it's done in Python).
11
+ return lhs.is(rhs);
12
+ }
13
+ // Otherwise, we first compare by identity for efficiency, then by value (see:
14
+ // [container equality])
15
+ return _fastEqualsForContainer(lhs, rhs);
16
+ }
17
+ }
18
+
19
+ template<class T> decltype(auto) getTypePtr();
20
+ std::string toString(const Type& type);
21
+
22
+ namespace impl {
23
+
24
+ template<class Key, class Value>
25
+ Dict<Key, Value> toTypedDict(GenericDict dict) {
26
+ TORCH_INTERNAL_ASSERT(*getTypePtr<Key>() == *dict.impl_->elementTypes.keyType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()), ">. Key types mismatch.");
27
+ TORCH_INTERNAL_ASSERT(*getTypePtr<Value>() == *dict.impl_->elementTypes.valueType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()), ">. Value types mismatch.");
28
+
29
+ return Dict<Key, Value>(std::move(dict.impl_));
30
+ }
31
+
32
+ template<class Key, class Value>
33
+ GenericDict toGenericDict(Dict<Key, Value> dict) {
34
+ return GenericDict(std::move(dict.impl_));
35
+ }
36
+ }
37
+
38
+ namespace detail {
39
+
40
+ inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
41
+ if (ivalue.isInt()) {
42
+ return std::hash<int64_t>()(ivalue.toInt());
43
+ } else if (ivalue.isString()) {
44
+ return std::hash<std::string_view>()(ivalue.toStringView());
45
+ } else if (ivalue.isDouble()) {
46
+ return std::hash<double>()(ivalue.toDouble());
47
+ } else if (ivalue.isComplexDouble()) {
48
+ return c10::hash<c10::complex<double>>()(ivalue.toComplexDouble());
49
+ } else if (ivalue.isBool()) {
50
+ return std::hash<bool>()(ivalue.toBool());
51
+ } else if (ivalue.isTensor()) {
52
+ return std::hash<TensorImpl*>()(ivalue.toTensor().unsafeGetTensorImpl());
53
+ } else if (ivalue.isDevice()) {
54
+ return std::hash<Device>()(ivalue.toDevice());
55
+ } else {
56
+ TORCH_CHECK(false, "Can't hash IValues with tag '", ivalue.tagKind(), "'");
57
+ }
58
+ }
59
+
60
+ inline intrusive_ptr<DictImpl> DictImpl::copy() const {
61
+ return make_intrusive<DictImpl>(dict, elementTypes);
62
+ }
63
+
64
+ }
65
+
66
+ template<class Key, class Value>
67
+ Dict<Key, Value>::Dict()
68
+ :Dict(make_intrusive<detail::DictImpl>(
69
+ detail::DictImpl::dict_map_type(),
70
+ detail::DictImpl::DictElementTypes{getTypePtr<Key>(), getTypePtr<Value>()})) {
71
+ static_assert(!std::is_same_v<Key, IValue>, "This constructor is not valid for Dict<IValue, _>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
72
+ static_assert(!std::is_same_v<Value, IValue>, "This constructor is not valid for Dict<_, IValue>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
73
+ }
74
+
75
+ template<class Key, class Value>
76
+ Dict<Key, Value>::Dict(TypePtr keyType, TypePtr valueType)
77
+ : Dict(make_intrusive<detail::DictImpl>(
78
+ detail::DictImpl::dict_map_type(),
79
+ detail::DictImpl::DictElementTypes {std::move(keyType), std::move(valueType)})) {
80
+ static_assert(std::is_same_v<Key, IValue>, "This constructor is only valid for c10::impl::GenericDict.");
81
+ static_assert(std::is_same_v<Value, IValue>, "This constructor is only valid for c10::impl::GenericDict.");
82
+ }
83
+
84
+ template<class Key, class Value>
85
+ Dict<Key, Value>::Dict(c10::intrusive_ptr<detail::DictImpl>&& impl): impl_(std::move(impl)) {}
86
+
87
+ template<class Key, class Value>
88
+ Dict<Key, Value> Dict<Key, Value>::copy() const {
89
+ return Dict<Key, Value>(impl_->copy());
90
+ }
91
+
92
+ template<class Key, class Value>
93
+ typename Dict<Key, Value>::iterator Dict<Key, Value>::begin() const {
94
+ return iterator{impl_->dict.begin()};
95
+ }
96
+
97
+ template<class Key, class Value>
98
+ typename Dict<Key, Value>::iterator Dict<Key, Value>::end() const {
99
+ return iterator{impl_->dict.end()};
100
+ }
101
+
102
+ template<class Key, class Value>
103
+ bool Dict<Key, Value>::empty() const {
104
+ return impl_->dict.empty();
105
+ }
106
+
107
+ template<class Key, class Value>
108
+ typename Dict<Key, Value>::size_type Dict<Key, Value>::size() const {
109
+ return impl_->dict.size();
110
+ }
111
+
112
+ template<class Key, class Value>
113
+ void Dict<Key, Value>::clear() const {
114
+ impl_->dict.clear();
115
+ }
116
+
117
+ template<class Key, class Value>
118
+ template<class Key_, class Value_>
119
+ std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert(Key_&& key, Value_&& value) const {
120
+ static_assert(std::is_constructible_v<Key, Key_>, "Wrong type for the key argument of Dict::insert");
121
+ static_assert(std::is_constructible_v<Value, Value_>, "Wrong type for the value argument of Dict::insert");
122
+ auto inserted = impl_->dict.emplace(
123
+ Key(std::forward<Key_>(key)),
124
+ Value(std::forward<Value_>(value)));
125
+ return {iterator{inserted.first}, inserted.second};
126
+ }
127
+
128
+ template<class Key, class Value>
129
+ template<class Key_, class Value_>
130
+ std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert_or_assign(Key_&& key, Value_&& value) const {
131
+ static_assert(std::is_constructible_v<Key, Key_>, "Wrong type for the key argument of Dict::insert_or_assign");
132
+ static_assert(std::is_constructible_v<Value, Value_>, "Wrong type for the value argument of Dict::insert_or_assign");
133
+ auto inserted = impl_->dict.insert_or_assign(
134
+ Key(std::forward<Key_>(key)),
135
+ Value(std::forward<Value_>(value)));
136
+ return {iterator{inserted.first}, inserted.second};
137
+ }
138
+
139
+ template<class Key, class Value>
140
+ void Dict<Key, Value>::erase(iterator iter) const {
141
+ impl_->dict.erase(iter.entryRef_.iterator_);
142
+ }
143
+
144
+ template <class Key, class Value>
145
+ [[nodiscard]] size_t Dict<Key, Value>::erase(const Key& key) const {
146
+ return impl_->dict.erase(key);
147
+ }
148
+
149
+ template<class Key, class Value>
150
+ Value Dict<Key, Value>::at(const Key& key) const {
151
+ return impl_->dict.at(key).template to<Value>();
152
+ }
153
+
154
+ template<class Key, class Value>
155
+ typename Dict<Key, Value>::iterator Dict<Key, Value>::find(const Key& key) const {
156
+ return iterator{impl_->dict.find(key)};
157
+ }
158
+
159
+ template<class Key, class Value>
160
+ bool Dict<Key, Value>::contains(const Key& key) const {
161
+ return end() != find(key);
162
+ }
163
+
164
+ template<class Key, class Value>
165
+ void Dict<Key, Value>::reserve(size_type count) const {
166
+ impl_->dict.reserve(count);
167
+ }
168
+
169
+ template<class Key, class Value>
170
+ TypePtr Dict<Key, Value>::keyType() const {
171
+ return impl_->elementTypes.keyType;
172
+ }
173
+
174
+ template<class Key, class Value>
175
+ TypePtr Dict<Key, Value>::valueType() const {
176
+ return impl_->elementTypes.valueType;
177
+ }
178
+ template <class Key, class Value>
179
+ void Dict<Key, Value>::unsafeSetKeyType(TypePtr t) {
180
+ impl_->elementTypes.keyType = std::move(t);
181
+ }
182
+
183
+ template <class Key, class Value>
184
+ void Dict<Key, Value>::unsafeSetValueType(TypePtr t) {
185
+ impl_->elementTypes.valueType = std::move(t);
186
+ }
187
+
188
+ template <class Key_, class Value_>
189
+ bool operator==(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
190
+ // Dicts with the same identity trivially compare equal.
191
+ if (lhs.impl_ == rhs.impl_) {
192
+ return true;
193
+ }
194
+
195
+ // Otherwise compare the values
196
+ return *lhs.impl_ == *rhs.impl_;
197
+ }
198
+
199
+ template <class Key_, class Value_>
200
+ bool operator!=(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
201
+ return !(lhs == rhs);
202
+ }
203
+
204
+ template <class Key, class Value>
205
+ bool Dict<Key, Value>::is(const Dict& rhs) const {
206
+ return this->impl_ == rhs.impl_;
207
+ }
208
+ }
.venv/lib/python3.12/site-packages/torch/include/ATen/core/DimVector.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/util/DimVector.h>
3
+
4
+ namespace at {
5
+
6
+ // Re-declaring 'DimVector' type and size inside 'at' namespace.
7
+ // This is done to avoid modifying every use into their 'c10'
8
+ // equivalent.
9
+
10
+ using c10::kDimVectorStaticSize;
11
+ using c10::DimVector;
12
+
13
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dimname.h ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/symbol.h>
4
+ #include <c10/util/ArrayRef.h>
5
+ #include <optional>
6
+ #include <ostream>
7
+
8
+ namespace at {
9
+
10
+ enum class NameType: uint8_t { BASIC, WILDCARD };
11
+
12
+ struct TORCH_API Dimname {
13
+ static Dimname fromSymbol(Symbol name);
14
+ static Dimname wildcard();
15
+ static bool isValidName(const std::string& name);
16
+
17
+ NameType type() const { return type_; }
18
+ Symbol symbol() const { return name_; }
19
+
20
+ bool isBasic() const { return type_ == NameType::BASIC; }
21
+ bool isWildcard() const { return type_ == NameType::WILDCARD; }
22
+
23
+ bool matches(Dimname other) const;
24
+ std::optional<Dimname> unify(Dimname other) const;
25
+
26
+ private:
27
+ Dimname(Symbol name)
28
+ : name_(name), type_(NameType::BASIC) {}
29
+ Dimname(Symbol name, NameType type)
30
+ : name_(name), type_(type) {}
31
+
32
+ Symbol name_;
33
+ NameType type_;
34
+ };
35
+
36
+ using DimnameList = c10::ArrayRef<Dimname>;
37
+
38
+ TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname);
39
+
40
+ inline bool operator==(const Dimname& lhs, const Dimname& rhs) {
41
+ return lhs.symbol() == rhs.symbol();
42
+ }
43
+
44
+ inline bool operator!=(const Dimname& lhs, const Dimname& rhs) {
45
+ return !(lhs == rhs);
46
+ }
47
+
48
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/DistributionsHelper.h ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/TransformationHelper.h>
4
+ #include <c10/util/Half.h>
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/MathConstants.h>
7
+ #include <c10/macros/Macros.h>
8
+
9
+ #include <cmath>
10
+ #include <limits>
11
+ #include <optional>
12
+ #include <type_traits>
13
+
14
+ /**
15
+ * Distributions kernel adapted from THRandom.cpp
16
+ * The kernels try to follow std::random distributions signature
17
+ * For instance: in ATen
18
+ * auto gen = at::detail::createCPUGenerator();
19
+ * at::uniform_real_distribution<double> uniform(0, 1);
20
+ * auto sample = uniform(gen.get());
21
+ *
22
+ * vs std::random
23
+ *
24
+ * std::mt19937 gen;
25
+ * std::uniform_real_distribution uniform(0, 1);
26
+ * auto sample = uniform(gen);
27
+ */
28
+
29
+
30
+ namespace at {
31
+ namespace {
32
+
33
+ /**
34
+ * Samples a discrete uniform distribution in the range [base, base+range) of type T
35
+ */
36
+ template <typename T>
37
+ struct uniform_int_from_to_distribution {
38
+
39
+ C10_HOST_DEVICE inline uniform_int_from_to_distribution(uint64_t range, int64_t base) : range_(range), base_(base) {}
40
+
41
+ template <typename RNG>
42
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
43
+ #ifdef FBCODE_CAFFE2
44
+ if ((
45
+ std::is_same_v<T, int64_t> ||
46
+ std::is_same_v<T, double> ||
47
+ std::is_same_v<T, float> ||
48
+ std::is_same_v<T, at::BFloat16>) && range_ >= 1ULL << 32)
49
+ #else
50
+ if (range_ >= 1ULL << 28) // allow approx 5% skew in uniform int generation using %
51
+ #endif
52
+ {
53
+ return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
54
+ } else {
55
+ return transformation::uniform_int_from_to<T>(generator->random(), range_, base_);
56
+ }
57
+ }
58
+
59
+ private:
60
+ uint64_t range_;
61
+ int64_t base_;
62
+ };
63
+
64
+ /**
65
+ * Samples a discrete uniform distribution in the range [min_value(int64_t), max_value(int64_t)]
66
+ */
67
+ template <typename T>
68
+ struct uniform_int_full_range_distribution {
69
+
70
+ template <typename RNG>
71
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
72
+ return transformation::uniform_int_full_range<T>(generator->random64());
73
+ }
74
+
75
+ };
76
+
77
+ /**
78
+ * Samples a discrete uniform distribution in the range [0, max_value(T)] for integral types
79
+ * and [0, 2^mantissa] for floating-point types.
80
+ */
81
+ template <typename T>
82
+ struct uniform_int_distribution {
83
+
84
+ template <typename RNG>
85
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
86
+ if constexpr (std::is_same_v<T, double> || std::is_same_v<T, int64_t>) {
87
+ return transformation::uniform_int<T>(generator->random64());
88
+ } else {
89
+ return transformation::uniform_int<T>(generator->random());
90
+ }
91
+ }
92
+
93
+ };
94
+
95
+ /**
96
+ * Samples a uniform distribution in the range [from, to) of type T
97
+ */
98
+ template <typename T>
99
+ struct uniform_real_distribution {
100
+
101
+ C10_HOST_DEVICE inline uniform_real_distribution(T from, T to) : from_(from), to_(to) {
102
+ TORCH_CHECK_IF_NOT_ON_CUDA(from <= to);
103
+ TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits<T>::max());
104
+ }
105
+
106
+ template <typename RNG>
107
+ C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
108
+ if constexpr (std::is_same_v<T, double>) {
109
+ return transformation::uniform_real<T>(generator->random64(), from_, to_);
110
+ } else {
111
+ return transformation::uniform_real<T>(generator->random(), from_, to_);
112
+ }
113
+ }
114
+
115
+ private:
116
+ T from_;
117
+ T to_;
118
+ };
119
+
120
+ // The SFINAE checks introduced in #39816 look overcomplicated and must be revisited
121
+ // https://github.com/pytorch/pytorch/issues/40052
122
// Generates has_member_<member><T>, whose ::value is true when the expression
// &T::member is well-formed. Overload resolution picks the first test()
// overload in that case and falls back to the variadic overload otherwise.
#define DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member)                        \
template <typename T>                                                          \
struct has_member_##member                                                     \
{                                                                              \
  template <typename U> static std::true_type test(decltype(&U::member));      \
  template <typename U> static std::false_type test(...);                      \
  static constexpr bool value = decltype(test<T>(0))::value;                   \
}

DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_double_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_double_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_float_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_float_normal_sample);
137
+
138
+ #define DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(TYPE) \
139
+ \
140
+ template <typename RNG, typename ret_type, \
141
+ typename std::enable_if_t<( \
142
+ has_member_next_##TYPE##_normal_sample<RNG>::value && \
143
+ has_member_set_next_##TYPE##_normal_sample<RNG>::value \
144
+ ), int> = 0> \
145
+ C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) { \
146
+ if (generator->next_##TYPE##_normal_sample()) { \
147
+ *ret = *(generator->next_##TYPE##_normal_sample()); \
148
+ generator->set_next_##TYPE##_normal_sample(std::optional<TYPE>()); \
149
+ return true; \
150
+ } \
151
+ return false; \
152
+ } \
153
+ \
154
+ template <typename RNG, typename ret_type, \
155
+ typename std::enable_if_t<( \
156
+ !has_member_next_##TYPE##_normal_sample<RNG>::value || \
157
+ !has_member_set_next_##TYPE##_normal_sample<RNG>::value \
158
+ ), int> = 0> \
159
+ C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type* /*ret*/) { \
160
+ return false; \
161
+ } \
162
+ \
163
+ template <typename RNG, typename ret_type, \
164
+ typename std::enable_if_t<( \
165
+ has_member_set_next_##TYPE##_normal_sample<RNG>::value \
166
+ ), int> = 0> \
167
+ C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \
168
+ generator->set_next_##TYPE##_normal_sample(cache); \
169
+ } \
170
+ \
171
+ template <typename RNG, typename ret_type, \
172
+ typename std::enable_if_t<( \
173
+ !has_member_set_next_##TYPE##_normal_sample<RNG>::value \
174
+ ), int> = 0> \
175
+ C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type /*cache*/) { \
176
+ }
177
+
178
+ DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(double)
179
+ DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float)
180
+
181
+ /**
182
+ * Samples a normal distribution using the Box-Muller method
183
+ * Takes mean and standard deviation as inputs
184
+ * Note that Box-muller method returns two samples at a time.
185
+ * Hence, we cache the "next" sample in the CPUGeneratorImpl class.
186
+ */
187
+ template <typename T>
188
+ struct normal_distribution {
189
+
190
+ C10_HOST_DEVICE inline normal_distribution(T mean_in, T stdv_in) : mean(mean_in), stdv(stdv_in) {
191
+ TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in >= 0, "stdv_in must be positive: ", stdv_in);
192
+ }
193
+
194
+ template <typename RNG>
195
+ C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
196
+ dist_acctype<T> ret;
197
+ // return cached values if available
198
+ if constexpr (std::is_same_v<T, double>) {
199
+ if (maybe_get_next_double_normal_sample(generator, &ret)) {
200
+ return transformation::normal(ret, mean, stdv);
201
+ }
202
+ } else {
203
+ if (maybe_get_next_float_normal_sample(generator, &ret)) {
204
+ return transformation::normal(ret, mean, stdv);
205
+ }
206
+ }
207
+ // otherwise generate new normal values
208
+ uniform_real_distribution<T> uniform(0.0, 1.0);
209
+ const dist_acctype<T> u1 = uniform(generator);
210
+ const dist_acctype<T> u2 = uniform(generator);
211
+ const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log1p(-u2));
212
+ const dist_acctype<T> theta = static_cast<T>(2.0) * c10::pi<T> * u1;
213
+ if constexpr (std::is_same_v<T, double>) {
214
+ maybe_set_next_double_normal_sample(generator, r * ::sin(theta));
215
+ } else {
216
+ maybe_set_next_float_normal_sample(generator, r * ::sin(theta));
217
+ }
218
+ ret = r * ::cos(theta);
219
+ return transformation::normal(ret, mean, stdv);
220
+ }
221
+
222
+ private:
223
+ T mean;
224
+ T stdv;
225
+ };
226
+
227
// Maps an element type to the floating-point type used for discrete
// distributions: float for everything, except double which maps to itself.
template <typename T>
struct DiscreteDistributionType {
  typedef float type;
};

template <>
struct DiscreteDistributionType<double> {
  typedef double type;
};
231
+
232
+ /**
233
+ * Samples a bernoulli distribution given a probability input
234
+ */
235
+ template <typename T>
236
+ struct bernoulli_distribution {
237
+
238
+ C10_HOST_DEVICE inline bernoulli_distribution(T p_in) : p(p_in) {
239
+ TORCH_CHECK_IF_NOT_ON_CUDA(p_in >= 0 && p_in <= 1);
240
+ }
241
+
242
+ template <typename RNG>
243
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
244
+ uniform_real_distribution<T> uniform(0.0, 1.0);
245
+ return transformation::bernoulli<T>(uniform(generator), p);
246
+ }
247
+
248
+ private:
249
+ T p;
250
+ };
251
+
252
+ /**
253
+ * Samples a geometric distribution given a probability input
254
+ */
255
+ template <typename T>
256
+ struct geometric_distribution {
257
+
258
+ C10_HOST_DEVICE inline geometric_distribution(T p_in) : p(p_in) {
259
+ TORCH_CHECK_IF_NOT_ON_CUDA(p_in > 0 && p_in < 1);
260
+ }
261
+
262
+ template <typename RNG>
263
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
264
+ uniform_real_distribution<T> uniform(0.0, 1.0);
265
+ return transformation::geometric<T>(uniform(generator), p);
266
+ }
267
+
268
+ private:
269
+ T p;
270
+ };
271
+
272
+ /**
273
+ * Samples an exponential distribution given a lambda input
274
+ */
275
+ template <typename T>
276
+ struct exponential_distribution {
277
+
278
+ C10_HOST_DEVICE inline exponential_distribution(T lambda_in) : lambda(lambda_in) {}
279
+
280
+ template <typename RNG>
281
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
282
+ uniform_real_distribution<T> uniform(0.0, 1.0);
283
+ return transformation::exponential<T>(uniform(generator), lambda);
284
+ }
285
+
286
+ private:
287
+ T lambda;
288
+ };
289
+
290
+ /**
291
+ * Samples a cauchy distribution given median and sigma as inputs
292
+ */
293
+ template <typename T>
294
+ struct cauchy_distribution {
295
+
296
+ C10_HOST_DEVICE inline cauchy_distribution(T median_in, T sigma_in) : median(median_in), sigma(sigma_in) {}
297
+
298
+ template <typename RNG>
299
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
300
+ uniform_real_distribution<T> uniform(0.0, 1.0);
301
+ return transformation::cauchy<T>(uniform(generator), median, sigma);
302
+ }
303
+
304
+ private:
305
+ T median;
306
+ T sigma;
307
+ };
308
+
309
+ /**
310
+ * Samples a lognormal distribution
311
+ * Takes mean and standard deviation as inputs
312
+ * Outputs two samples at a time
313
+ */
314
+ template <typename T>
315
+ struct lognormal_distribution {
316
+
317
+ C10_HOST_DEVICE inline lognormal_distribution(T mean_in, T stdv_in) : mean(mean_in), stdv(stdv_in) {
318
+ TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0);
319
+ }
320
+
321
+ template<typename RNG>
322
+ C10_HOST_DEVICE inline T operator()(RNG generator){
323
+ normal_distribution<T> normal(mean, stdv);
324
+ return transformation::log_normal<T>(normal(generator));
325
+ }
326
+
327
+ private:
328
+ T mean;
329
+ T stdv;
330
+ };
331
+ }
332
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Formatting.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ostream>
4
+ #include <string>
5
+
6
+ #include <c10/core/Scalar.h>
7
+ #include <ATen/core/Tensor.h>
8
+
9
+ namespace c10 {
10
+ TORCH_API std::ostream& operator<<(std::ostream& out, Backend b);
11
+ TORCH_API std::ostream& operator<<(std::ostream & out, const Scalar& s);
12
+ TORCH_API std::string toString(const Scalar& s);
13
+ }
14
+ namespace at {
15
+
16
+ TORCH_API std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t);
17
+ TORCH_API std::ostream& print(
18
+ std::ostream& stream,
19
+ const Tensor& tensor,
20
+ int64_t linesize);
21
+ inline std::ostream& operator<<(std::ostream & out, const Tensor & t) {
22
+ return print(out,t,80);
23
+ }
24
+ TORCH_API void print(const Tensor & t, int64_t linesize=80);
25
+ }
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Generator.h ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cstdint>
4
+ #include <deque>
5
+ #include <mutex>
6
+ #include <utility>
7
+
8
+ #include <c10/util/Exception.h>
9
+ #include <c10/util/intrusive_ptr.h>
10
+ #include <c10/core/Device.h>
11
+ #include <c10/core/DispatchKeySet.h>
12
+
13
+ // For the record I don't think this is a correct pimpl idiom.
14
+ // Including Impl header in interface header defeats the purpose
15
+ // because you can't change Impl private members without forcing
16
+ // everything that included the interface to rebuild.
17
+ // Impl should be forward-declared in the interface header instead.
18
+ #include <c10/core/GeneratorImpl.h>
19
+
20
+ /**
21
+ * Note [Generator]
22
+ * ~~~~~~~~~~~~~~~~
23
+ * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm to
24
+ * generate a seemingly random sequence of numbers, that may later be used in creating
25
+ * a random distribution. Such an engine almost always maintains a state and requires a
26
+ * seed to start off the creation of random numbers. Often times, users have
27
+ * found it beneficial to be able to explicitly create, retain, and destroy
28
+ * PRNG states and also be able to have control over the seed value.
29
+ *
30
+ * A Generator in ATen gives users the ability to read, write and modify a PRNG engine.
31
+ * For instance, it does so by letting users seed a PRNG engine, fork the state of the
32
+ * engine, etc.
33
+ *
34
+ * By default, there is one generator per device, and a device's generator is
35
+ * lazily created. A user can use the torch.Generator() api to create their own generator.
36
+ */
37
+
38
+ /**
39
+ * Note [Acquire lock when using random generators]
40
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
41
+ * Generator and its derived classes are NOT thread-safe. Please note that most of the
42
+ * places where we have inserted locking for generators are historically based, and we
43
+ * haven't actually checked that everything is truly thread safe (and it probably isn't).
44
+ * Please use the public mutex_ when using any methods from these classes, except for the
45
+ * read-only methods. You can learn about the usage by looking into the unittests
46
+ * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard.
47
+ *
48
+ * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making
49
+ * them non-thread safe and instead making the generator state splittable, to accommodate
50
+ * forks into other threads).
51
+ */
52
+
53
+ namespace at {
54
+
55
+ class Tensor;
56
+
57
+ struct TORCH_API Generator {
58
+ Generator() = default;
59
+
60
+ explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
61
+ : impl_(std::move(gen_impl)) {
62
+ if (impl_.get() == nullptr) {
63
+ throw std::runtime_error("GeneratorImpl with nullptr is not supported");
64
+ }
65
+ }
66
+
67
+ bool operator==(const Generator& rhs) const {
68
+ return this->impl_ == rhs.impl_;
69
+ }
70
+
71
+ bool operator!=(const Generator& rhs) const {
72
+ return !((*this) == rhs);
73
+ }
74
+
75
+ bool defined() const {
76
+ return static_cast<bool>(impl_);
77
+ }
78
+
79
+ c10::GeneratorImpl* unsafeGetGeneratorImpl() const {
80
+ return impl_.get();
81
+ }
82
+
83
+ c10::GeneratorImpl* unsafeReleaseGeneratorImpl() {
84
+ return impl_.release();
85
+ }
86
+
87
+ const c10::intrusive_ptr<c10::GeneratorImpl>& getIntrusivePtr() const {
88
+ return impl_;
89
+ }
90
+
91
+ void set_current_seed(uint64_t seed) { impl_->set_current_seed(seed); }
92
+ // Sets the offset of Generator state to the desired offset. This is currently
93
+ // supported for only Philox based Generators, i.e., CUDA and MPS.
94
+ void set_offset(uint64_t offset) { impl_->set_offset(offset); }
95
+
96
+ // Returns the offset of Generator state. This is currently supported for only
97
+ // Philox based Generators, i.e., CUDA and MPS.
98
+ uint64_t get_offset() const { return impl_->get_offset(); }
99
+
100
+ uint64_t current_seed() const { return impl_->current_seed(); }
101
+
102
+ uint64_t seed() { return impl_->seed(); }
103
+
104
+ // Implementation not inlined to prevent cycle reference between
105
+ // `ATen/core/Generator.h` and `ATen/core/Tensor.h`
106
+ void set_state(const at::Tensor& new_state);
107
+
108
+ at::Tensor get_state() const;
109
+
110
+ void graphsafe_set_state(const Generator& new_state);
111
+
112
+ Generator graphsafe_get_state() const;
113
+
114
+ std::mutex& mutex() {
115
+ return impl_->mutex_;
116
+ }
117
+
118
+ DispatchKeySet key_set() const {
119
+ return impl_->key_set();
120
+ }
121
+
122
+ Device device() const { return impl_->device(); }
123
+
124
+ inline void set_pyobj(PyObject* pyobj) const noexcept {
125
+ impl_->set_pyobj(pyobj);
126
+ }
127
+
128
+ inline PyObject* pyobj() const noexcept {
129
+ return impl_->pyobj();
130
+ }
131
+
132
+ template<typename T>
133
+ T* get() const { return static_cast<T*>(impl_.get()); }
134
+
135
+ Generator clone() const {
136
+ return Generator(impl_->clone());
137
+ }
138
+
139
+ private:
140
+ c10::intrusive_ptr<c10::GeneratorImpl> impl_;
141
+ };
142
+
143
+ template<class Impl, class... Args>
144
+ Generator make_generator(Args&&... args) {
145
+ return Generator(c10::make_intrusive<Impl>(std::forward<Args>(args)...));
146
+ }
147
+
148
+ /**
149
+ * Utility function to static cast input Generator* to
150
+ * the backend generator type (CPU/CUDAGeneratorImpl etc.)
151
+ */
152
+ template <typename T>
153
+ inline T * check_generator(std::optional<Generator> gen) {
154
+ TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");
155
+ TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");
156
+ TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");
157
+ return gen->get<T>();
158
+ }
159
+
160
+ /**
161
+ * Utility function used in tensor implementations, which
162
+ * supplies the default generator to tensors, if an input generator
163
+ * is not supplied. The input Generator* is also static casted to
164
+ * the backend generator type (CPU/CUDAGeneratorImpl etc.)
165
+ */
166
+ template <typename T>
167
+ inline T* get_generator_or_default(const std::optional<Generator>& gen, const Generator& default_gen) {
168
+ return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
169
+ }
170
+
171
+ namespace detail {
172
+
173
+ /**
174
+ * Helper function for checking the validity of new random generator
175
+ * state. Right now following conditions are checked:
176
+ *
177
+ * - The new state tensor must be a torch.ByteTensor
178
+ * - Data of the new state tensor must be contiguous
179
+ */
180
+ inline void check_rng_state(const c10::TensorImpl& new_state) {
181
+ TORCH_CHECK_TYPE(
182
+ new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte,
183
+ "RNG state must be a torch.ByteTensor"
184
+ );
185
+
186
+ TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous");
187
+ }
188
+
189
+ } // namespace detail
190
+
191
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Generator.h>
4
+ #include <c10/util/intrusive_ptr.h>
5
+
6
+ namespace at {
7
+
8
+ using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>;
9
+
10
+ TORCH_API std::optional<GeneratorFuncType>& GetGeneratorPrivate();
11
+
12
+ // Registration helper: its constructor receives the generator factory function.
+ // REGISTER_GENERATOR_PRIVATEUSE1 (end of this header) creates a static instance of it.
+ // Only the declaration is visible here; what the constructor does with `func` is
+ // defined elsewhere (presumably it stores it where GetGeneratorPrivate() can return it — TODO confirm).
+ class TORCH_API _GeneratorRegister {
13
+ public:
14
+ explicit _GeneratorRegister(const GeneratorFuncType& func);
15
+ };
16
+
17
+ TORCH_API at::Generator GetGeneratorForPrivateuse1(
18
+ c10::DeviceIndex device_index);
19
+
20
+ /**
21
+ * This is used to register Generator to PyTorch for `privateuse1` key.
22
+ *
23
+ * Usage: REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1)
24
+ *
25
+ * class CustomGeneratorImpl : public c10::GeneratorImpl {
26
+ * CustomGeneratorImpl(DeviceIndex device_index = -1);
27
+ * ~CustomGeneratorImpl() override = default;
28
+ * ...
29
+ * };
30
+ *
31
+ * at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) {
32
+ * return at::make_generator<CustomGeneratorImpl>(id);
33
+ * }
34
+ */
35
+
36
// Defines a static at::_GeneratorRegister named temp<GeneratorPrivate>,
// constructed from the given factory function. The macro already ends with a
// semicolon.
#define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \
  static at::_GeneratorRegister temp##GeneratorPrivate = \
      at::_GeneratorRegister(GeneratorPrivate);
38
+
39
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef.h ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/ivalue_to.h>
4
+ #include <c10/util/ArrayRef.h>
5
+ #include <c10/util/Exception.h>
6
+
7
+ #include <functional>
8
+ #include <initializer_list>
9
+ #include <iterator>
10
+ #include <type_traits>
11
+
12
+ /*
13
+ * [Note: IListRef]
14
+ * Wrapper around different API containers (e.g. boxed and unboxed).
15
+ *
16
+ * What is it?
17
+ * ===========
18
+ * It is a tagged union of both boxed and unboxed API containers.
19
+ * Working implementations:
20
+ *
21
+ * - `IListRef<at::Tensor>`
22
+ * - `IListRef<at::OptionalTensorRef>`
23
+ *
24
+ * Note that `IListRef` is a view type. Meaning that it won't own the
25
+ * tensors it holds. It's intended to be used only as argument parameters.
26
+ * Specifically, where these 2 worlds overlap.
27
+ *
28
+ * What is this for?
29
+ * =================
30
+ * Historically, PyTorch has maintained 2 different APIs: the unboxed
31
+ * (called from C++ API and Python eager mode) and boxed APIs (called
32
+ * from the TorchScript JIT, mobile interpreter, and boxed fallbacks).
33
+ *
34
+ * Calling unboxed kernels from the boxed "world" and vice-versa may
35
+ * result in non-negligible overhead. Lists are one of those types:
36
+ *
37
+ * - Boxed world: `c10::List`
38
+ * - Unboxed world: `c10::ArrayRef`
39
+ *
40
+ * In this context, `c10::IListRef` solves this problem by wrapping those
41
+ * 2 container types, so that we don't need to convert from one to
42
+ * the other.
43
+ *
44
+ * (see https://github.com/pytorch/pytorch/issues/66328)
45
+ *
46
+ * What does it do?
47
+ * ================
48
+ * This container wraps around the different tagged containers
49
+ * (currently, only boxed and unboxed), without incurring extra
50
+ * overhead for converting from one to another. It does so while
51
+ * exposing usual container methods, which dispatch to corresponding
52
+ * implementations.
53
+ *
54
+ * While it works with different container types, it introduces
55
+ * overhead for repeatedly calling member functions (since those will
56
+ * get dispatched, again). Therefore, you should only use it to iterate
57
+ * through the list up to one time. If you need to do more complex things,
58
+ * call `materialize()` first.
59
+ *
60
+ * Adding support for a new Tag
61
+ * ============================
62
+ * Suppose we want to add a new tag: `Chest`. Here are the steps
63
+ * we would have to go through:
64
+ *
65
+ * 1. Add a line for it in the macro `TORCH_ILISTREF_FORALL_TAGS`.
66
+ *
67
+ * #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \
68
+ * ...
69
+ * _(Chest, ##__VA_ARGS__)
70
+ *
71
+ * 2. Add type aliases, union members, and constructors.
72
+ *
73
+ * template <typename T>
74
+ * class IListRef {
75
+ * ...
76
+ * using chest_type =
77
+ * typename detail::IListRefTagImpl<IListRefTag::Chest, T>::list_type;
78
+ * ...
79
+ * IListRef(...) : tag_(IListRefTag::Chest) {
80
+ * ...
81
+ * }
82
+ * ...
83
+ * union Payload {
84
+ * ...
85
+ * chest_type chest;
86
+ * ...
87
+ * };
88
+ * ...
89
+ * };
90
+ *
91
+ * 3. Add a default implementation for it (in 'IListRef_inl.h'). It's
92
+ * preferable to make the default implementation work for `T = Tensor`
93
+ * (both `Unboxed` and `Boxed` do it).
94
+ *
95
+ * template <typename T, typename ListElemT>
96
+ * class IListRefTagImplBase<IListRefTag::Chest, T, ListElemT> {
97
+ * public:
98
+ * using elem_type = ListElemT;
99
+ * using list_type = ChestContainer<elem_type>;
100
+ *
101
+ * static const list_type& unwrap(const IListRef<T>& ilist) { ... }
102
+ *
103
+ * static typename list_type::const_iterator& unwrap(
104
+ * IListRefIterator<T>& it) { ... }
105
+ *
106
+ * static const typename list_type::const_iterator& unwrap(
107
+ * const IListRefIterator<T>& it) { ... }
108
+ *
109
+ * static IListRefConstRef<T> iterator_get(
110
+ * const typename list_type::const_iterator& it) { ... }
111
+ * }
112
+ *
113
+ * 4. Add a specialization for each of the already supported types.
114
+ * Finally, for consistency, add them to the tracking list.
115
+ * (see [Note: IListRefTagImpl Specializations])
116
+ *
117
+ * template <>
118
+ * class IListRefTagImpl<IListRefTag::Chest, at::Tensor>
119
+ * : public IListRefTagImplBase<IListRefTag::Chest, at::Tensor> {};
120
+ *
121
+ * Adding support for a new Type
122
+ * =============================
123
+ * Suppose we want to add support for a new type: `Matrix`.
124
+ * Here are the steps we would have to go through:
125
+ *
126
+ * 1. Add a specialization for each of the existing tags.
127
+ * For consistency, add them to the tracking list.
128
+ * (see [Note: IListRefTagImpl Specializations])
129
+ *
130
+ * template <>
131
+ * class IListRefTagImpl<IListRefTag::Unboxed, Matrix>
132
+ * : public IListRefTagImplBase<IListRefTag::Unboxed, Matrix> {};
133
+ *
134
+ * template <>
135
+ * class IListRefTagImpl<IListRefTag::Boxed, Matrix>
136
+ * : public IListRefTagImplBase<IListRefTag::Boxed, Matrix> {};
137
+ *
138
+ * Common Problems
139
+ * ===============
140
+ * 1. One of the `IListRef(Iterator)` methods is failing to compile.
141
+ *
142
+ * That may be happening because the container type you added
143
+ * is not compatible with the code written for that method. If
144
+ * that's true, then you might have to transform that code into
145
+ * a static method call (see `List::operator[]` method).
146
+ *
147
+ * 2. Can't make `IListRefIterator<T>::operator*` return a const-reference.
148
+ *
149
+ * First, keep in mind that we assume that boxed containers will
150
+ * have to deal with `IValue` (e.g. `c10::List`). In this context,
151
+ * what may be happening is that `IValue` doesn't store internally
152
+ * your type `T`. Instead, it constructs a new `T` every time
153
+ * you try to get `T` from it (see `IListRef<at::OptionalTensorRef>`).
154
+ */
155
+
156
+ namespace c10 {
157
+ template <typename T>
158
+ class IListRef;
159
+
160
+ /*
161
+ * Applies arbitrary macros to each `IListRefTag`.
162
+ */
163
+ #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \
164
+ _(Unboxed, ##__VA_ARGS__) \
165
+ _(Boxed, ##__VA_ARGS__) \
166
+ _(Materialized, ##__VA_ARGS__)
167
+
168
+ /*
169
+ * Defines a "switch-case" for `TAG`. Inside, it executes `BODY`,
170
+ * while bringing to scope:
171
+ *
172
+ * - `ImplT`: the implementation class for `TAG`
173
+ * - `this_`: the result of unwrapping `this`
174
+ */
175
+ #define TORCH_ILISTREF_UNWRAP_CASE(TAG, BODY) \
176
+ case c10::IListRefTag::TAG: { \
177
+ using ImplT = c10::detail::IListRefTagImpl<IListRefTag::TAG, T>; \
178
+ auto& this_ = ImplT::unwrap(*this); \
179
+ BODY \
180
+ } break;
181
+
182
+ /*
183
+ * Dispatches the unwrap call, depending on `TAG`, followed by
184
+ * the execution of `BODY`. It aborts if `TAG` is not a `IListRefTag`.
185
+ *
186
+ * This macro is useful because it allows us to handle different
187
+ * types (that correspond to different tags) to be implemented
188
+ * only once. We can do it even when the implementation of the
189
+ * different tags aren't syntactically the same, by dispatching
190
+ * it to a function (e.g. `ImplT::<dispatch-function>(this_)`).
191
+ */
192
+ #define TORCH_ILISTREF_UNWRAP(TAG, BODY) \
193
+ switch (TAG) { \
194
+ TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \
195
+ break; \
196
+ default: \
197
+ TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); \
198
+ }
199
+
200
+ enum class IListRefTag {
201
+ #define DEFINE_TAG(tag, ...) tag,
202
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_TAG)
203
+ #undef DEFINE_TAG
204
+ None
205
+ };
206
+
207
+ namespace detail {
208
+ /*
209
+ * Type alias that specifies whether we return a reference or a copy of `T`.
210
+ *
211
+ * What is this for?
212
+ * =================
213
+ * Since values in the boxed world are represented by an `IValue`, we also
214
+ * depend on whether it can be converted to a const-reference (`Tensor`) or
215
+ * has to create a new copy of `T` (`OptionalTensorRef`).
216
+ */
217
+ template <typename T>
218
+ using IListRefConstRef = typename ivalue_to_const_ref_overload_return<T>::type;
219
+
220
+ /*
221
+ * Interface that implements key functions for each `IListRefTag` type.
222
+ *
223
+ * What is this for?
224
+ * =================
225
+ * Given an `IListRef(Iterator)<T>`, some methods have to be implemented
226
+ * differently for each `TAG`. Therefore, the methods inside this class
227
+ * are used as dispatch targets for the different `IListRefTag` values.
228
+ *
229
+ * You should create a specialization of this class for each possible
230
+ * combination of `IListRefTag` type (except `None`) and element types
231
+ * (e.g. `Tensor`).
232
+ *
233
+ * What does it do?
234
+ * ================
235
+ * 1. defines static methods to be used as dispatch targets by both
236
+ * `IListRef<T>` and `IListRefIterator<T>` (see the implementation of
237
+ * `IListRefTagImplBase`).
238
+ *
239
+ * 2. defines the `elem_type` and `list_type` aliases that will be
240
+ * used in the definition of `IListRef<T>`. In general, we should do
241
+ * so by inheriting from `IListRefTagImplBase<TAG, T, ListElemT>`.
242
+ *
243
+ * [Note: IListRefTagImpl Specializations]
243
+ * =======================================
245
+ * For `IListRef(Iterator)<at::Tensor>`:
246
+ * - <IListRefTag::Unboxed, at::Tensor>
247
+ * - <IListRefTag::Boxed, at::Tensor>
248
+ * - <IListRefTag::Materialized, at::Tensor>
249
+ *
250
+ * For `IListRef(Iterator)<at::OptionalTensorRef>`:
251
+ * - <IListRefTag::Unboxed, at::OptionalTensorRef>
252
+ * - <IListRefTag::Boxed, at::OptionalTensorRef>
253
+ * - <IListRefTag::Materialized, at::OptionalTensorRef>
254
+ */
255
+ template <IListRefTag TAG, typename T>
256
+ class IListRefTagImpl {};
257
+
258
+ /*
259
+ * Base implementation of `IListRefTagImpl<TAG, T>` methods.
260
+ *
261
+ * What is this for?
262
+ * =================
263
+ * This should make adding specializations for new types easier. For
264
+ * example, one should be able to add a new type just by making its
265
+ * `IListRefTagImpl` specialization inherit from `IListRefTagImplBase`.
266
+ *
267
+ * You should create a partial specialization for this class only if
268
+ * you introduce a new `IListRefTag`. The idea being that there is one
269
+ * default implementation for each possible value of `IListRefTag`.
270
+ *
271
+ * What does it do?
272
+ * ================
273
+ * 1. defines `elem_type` as an alias to `ListElemT`.
274
+ *
275
+ * 2. defines `list_type` as an alias to the default container type
276
+ * that will hold a collection of `elem_type`. The idea being that
277
+ * all types tagged as `TAG` will have `list_type` as its container,
278
+ * with different `elem_type`.
279
+ *
280
+ * 3. defines the default implementation for each of the methods that
281
+ * are supposed to be defined on `IListRefTagImpl` specializations.
282
+ *
283
+ * 4. inheriting from `IListRefTagImplBase<TAG, T, ListElemT>` also means
284
+ * that the payload of the type `IListRef<T>` will be of type `list_type`
285
+ * when it is tagged as `TAG`.
286
+ */
287
+ template <IListRefTag TAG, typename T, typename ListElemT = T>
288
+ class IListRefTagImplBase {};
289
+
290
+ /*
291
+ * Materialized container for `IListRef<T>`.
292
+ *
293
+ * What is this for?
294
+ * =================
295
+ * Container that groups `T` references together. This exchanges the
296
+ * overhead of every method call from `IListRef<T>` for a dynamic allocation.
297
+ *
298
+ * You should use this container instead of `IListRef<T>` if:
299
+ *
300
+ * - You are going to iterate the list more than once
301
+ * - You need to repeatedly access arbitrary elements (using `operator[]`)
302
+ *
303
+ * What does it do?
304
+ * ================
305
+ * Removes the reference (&) from the type, and wraps it into a
306
+ * `std::reference_wrapper`. If `IListRefConstRef<T>` is not a
307
+ * reference type, then it's left unchanged.
308
+ */
309
+ template <typename T>
310
+ using _MaterializedIListRefElem = std::conditional_t<
311
+ std::is_reference_v<T>,
312
+ typename std::reference_wrapper<std::remove_reference_t<T>>,
313
+ T>;
314
+
315
+ template <typename T>
316
+ using MaterializedIListRefElem = _MaterializedIListRefElem<IListRefConstRef<T>>;
317
+
318
+ template <typename T>
319
+ using MaterializedIListRef = std::vector<MaterializedIListRefElem<T>>;
320
+
321
+ } // namespace detail
322
+
323
+ /*
324
+ * Iterator for `IListRef<T>`.
325
+ *
326
+ * What is it?
327
+ * ===========
328
+ * Currently, a `std::bidirectional_iterator` that wraps the iterator
329
+ * types defined for each of the `IListRefTag`.
330
+ *
331
+ * One should be able to use it, as if it were the unwrapped
332
+ * iterators themselves.
333
+ *
334
+ * What does it do?
335
+ * ================
336
+ * Similarly to `IListRef<T>`, this is a wrapper class. Specifically, it
337
+ * wraps each container's `const_iterator` type alias. So, for example,
338
+ * given that the container for `IListRefTag::Boxed` is `c10::List`, this
339
+ * iterator will wrap a `c10::List::const_iterator`.
340
+ *
341
+ * [Note: MSVC Iterator Debug]
342
+ * ===========================
343
+ * MSVC `vector<T>::iterator` implementation (used in the boxed variant)
344
+ * makes it so this union's destructor, copy-constructor (assignment), and
345
+ * move-constructor (assignment) are implicitly deleted.
346
+ *
347
+ * Therefore, we need to explicitly define them as needed. Follows a list
348
+ * of places where these are needed and their reason:
349
+ *
350
+ * - `Payload` destructor:
351
+ * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is set to 2.
352
+ *
353
+ * - `IListRefIterator` destructor:
354
+ * same as above. However, we need to explicitly call the variant
355
+ * destructor.
356
+ *
357
+ * - `IListRefIterator` copy-constructor:
358
+ * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is different
359
+ * than 0.
360
+ */
361
+ template <typename T>
362
+ class IListRefIterator {
363
+ private:
364
+ #define DEFINE_FRIEND_CLASS(TAG, ...) \
365
+ friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \
366
+ friend class detail::IListRefTagImplBase< \
367
+ IListRefTag::TAG, \
368
+ T, \
369
+ typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>;
370
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
371
+ #undef DEFINE_FRIEND_CLASS
372
+
373
+ public:
374
+ // C++17 friendly std::iterator implementation
375
+ using iterator_category = std::bidirectional_iterator_tag;
376
+ using value_type = T;
377
+ using difference_type = std::ptrdiff_t;
378
+ using pointer = T*;
379
+ using reference = T&;
380
+
381
+ using unboxed_iterator_type = typename detail::
382
+ IListRefTagImpl<IListRefTag::Unboxed, T>::list_type::const_iterator;
383
+ using boxed_iterator_type = typename detail::
384
+ IListRefTagImpl<IListRefTag::Boxed, T>::list_type::const_iterator;
385
+ using materialized_iterator_type =
386
+ typename detail::MaterializedIListRef<T>::const_iterator;
387
+
388
+ IListRefIterator() : tag_(IListRefTag::None) {}
389
+
390
+ #if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL != 0
391
+ // See [Note: MSVC Iterator Debug]
392
+ IListRefIterator(const IListRefIterator& iterator)
393
+ : tag_(iterator.tag_) {
394
+ switch (tag_) {
395
+ case IListRefTag::Boxed:
396
+ payload_.boxed_iterator = iterator.payload_.boxed_iterator;
397
+ break;
398
+ case IListRefTag::Unboxed:
399
+ payload_.unboxed_iterator = iterator.payload_.unboxed_iterator;
400
+ break;
401
+ case IListRefTag::Materialized:
402
+ payload_.materialized_iterator = iterator.payload_.materialized_iterator;
403
+ break;
404
+ default:
405
+ TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
406
+ }
407
+ }
408
+ #endif
409
+
410
+ #if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL == 2
411
+ // See [Note: MSVC Iterator Debug]
412
+ ~IListRefIterator() noexcept(false) {
413
+ switch (tag_) {
414
+ case IListRefTag::Boxed:
415
+ payload_.boxed_iterator.~boxed_iterator_type();
416
+ break;
417
+ case IListRefTag::Unboxed:
418
+ payload_.unboxed_iterator.~unboxed_iterator_type();
419
+ break;
420
+ case IListRefTag::Materialized:
421
+ payload_.materialized_iterator.~materialized_iterator_type();
422
+ break;
423
+ default:
424
+ TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
425
+ }
426
+ }
427
+ #endif
428
+
429
+ IListRefIterator(boxed_iterator_type boxed) : tag_(IListRefTag::Boxed) {
430
+ payload_.boxed_iterator = boxed;
431
+ }
432
+
433
+ IListRefIterator(unboxed_iterator_type unboxed) : tag_(IListRefTag::Unboxed) {
434
+ payload_.unboxed_iterator = unboxed;
435
+ }
436
+
437
+ IListRefIterator(materialized_iterator_type materialized) : tag_(IListRefTag::Materialized) {
438
+ payload_.materialized_iterator = materialized;
439
+ }
440
+
441
+ detail::IListRefConstRef<T> operator*() const {
442
+ TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::iterator_get(this_); });
443
+ }
444
+
445
+ IListRefIterator& operator++() {
446
+ TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
447
+ return *this;
448
+ }
449
+
450
+ IListRefIterator operator++(int) {
451
+ auto old = *this;
452
+ TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
453
+ return old;
454
+ }
455
+
456
+ IListRefIterator& operator--() {
457
+ TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
458
+ return *this;
459
+ }
460
+
461
+ IListRefIterator operator--(int) {
462
+ auto old = *this;
463
+ TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
464
+ return old;
465
+ }
466
+
467
+ bool operator==(const IListRefIterator& rhs) const {
468
+ if (tag_ != rhs.tag_) {
469
+ return false;
470
+ }
471
+ TORCH_ILISTREF_UNWRAP(tag_, {
472
+ auto& rhs_it = ImplT::unwrap(rhs);
473
+ return this_ == rhs_it;
474
+ });
475
+ }
476
+
477
+ bool operator!=(const IListRefIterator& rhs) const {
478
+ return !(*this == rhs);
479
+ }
480
+
481
+ private:
482
+ union Payload {
483
+ boxed_iterator_type boxed_iterator;
484
+ unboxed_iterator_type unboxed_iterator;
485
+ materialized_iterator_type materialized_iterator;
486
+ void* _init_ptr;
487
+ Payload() : _init_ptr(nullptr) {}
488
+ #if defined(_MSC_VER)
489
+ // See [Note: MSVC Iterator Debug]
490
+ ~Payload() {}
491
+ #endif
492
+ };
493
+
494
+ Payload payload_;
495
+ IListRefTag tag_;
496
+ };
497
+
498
+ /*
499
+ * See [Note: IListRef]
500
+ */
501
+ template <typename T>
502
+ class IListRef {
503
+ private:
504
+ #define DEFINE_FRIEND_CLASS(TAG, ...) \
505
+ friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \
506
+ friend class detail::IListRefTagImplBase< \
507
+ IListRefTag::TAG, \
508
+ T, \
509
+ typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>;
510
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
511
+ #undef DEFINE_FRIEND_CLASS
512
+
513
+ public:
514
+ using unboxed_type =
515
+ typename detail::IListRefTagImpl<IListRefTag::Unboxed, T>::list_type;
516
+ using boxed_type =
517
+ typename detail::IListRefTagImpl<IListRefTag::Boxed, T>::list_type;
518
+ using materialized_type =
519
+ typename detail::MaterializedIListRef<T>;
520
+
521
+ using iterator = IListRefIterator<T>;
522
+ using const_iterator = IListRefIterator<T>;
523
+ using reverse_iterator = std::reverse_iterator<iterator>;
524
+ using value_type = typename iterator::value_type;
525
+
526
+ IListRef() : tag_(IListRefTag::None) {}
527
+
528
+ IListRef(const boxed_type& boxed) : tag_(IListRefTag::Boxed) {
529
+ payload_.boxed = &boxed;
530
+ }
531
+
532
+ IListRef(const unboxed_type& unboxed) : tag_(IListRefTag::Unboxed) {
533
+ payload_.unboxed = unboxed;
534
+ }
535
+
536
+ IListRef(const std::initializer_list<T>& list) : tag_(IListRefTag::Unboxed) {
537
+ payload_.unboxed = at::ArrayRef<T>(list);
538
+ }
539
+
540
+ template <
541
+ typename... UnboxedConstructorArgs,
542
+ typename = std::enable_if_t<
543
+ std::is_constructible_v<unboxed_type, UnboxedConstructorArgs...>>>
544
+ IListRef(UnboxedConstructorArgs&&... args) : tag_(IListRefTag::Unboxed) {
545
+ payload_.unboxed = unboxed_type(std::forward<UnboxedConstructorArgs>(args)...);
546
+ }
547
+
548
+ IListRef(const materialized_type& materialized) : tag_(IListRefTag::Materialized) {
549
+ payload_.materialized = &materialized;
550
+ }
551
+
552
+ size_t size() const {
553
+ TORCH_ILISTREF_UNWRAP(tag_, { return this_.size(); });
554
+ }
555
+
556
+ bool empty() const {
557
+ return size() == 0;
558
+ }
559
+
560
+ iterator begin() const {
561
+ TORCH_ILISTREF_UNWRAP(tag_, { return this_.begin(); });
562
+ }
563
+
564
+ iterator end() const {
565
+ TORCH_ILISTREF_UNWRAP(tag_, { return this_.end(); });
566
+ }
567
+
568
+ detail::IListRefConstRef<T> front() const {
569
+ TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::front(this_); });
570
+ }
571
+
572
+ /*
573
+ * Materializes the `IListRef` into a `std::vector`.
574
+ *
575
+ * This should be used when one wishes to either:
576
+ *
577
+ * - iterate over the list more than once: each `IListRefIterator`
578
+ * member function call has to go through a switch, introducing
579
+ * non-negligible overhead
580
+ *
581
+ * - randomly access an arbitrary element using `operator[]`:
582
+ * same reason as above
583
+ */
584
+ detail::MaterializedIListRef<T> materialize() const {
585
+ if (isMaterialized()) {
586
+ return toMaterialized();
587
+ }
588
+
589
+ detail::MaterializedIListRef<T> materialized;
590
+ materialized.reserve(size());
591
+ for (const auto& t : *this) {
592
+ materialized.emplace_back(t);
593
+ }
594
+ return materialized;
595
+ }
596
+
597
+ #define DEFINE_CHECK(TAG, ...) \
598
+ bool is##TAG() const { \
599
+ return tag_ == IListRefTag::TAG; \
600
+ }
601
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_CHECK)
602
+ #undef DEFINE_CHECK
603
+
604
+ bool isNone() const {
605
+ return tag_ == IListRefTag::None;
606
+ }
607
+
608
+ #define DEFINE_CASTING(TAG, ...) \
609
+ const typename detail::IListRefTagImpl<IListRefTag::TAG, T>::list_type& \
610
+ to##TAG() const { \
611
+ TORCH_INTERNAL_ASSERT(is##TAG()); \
612
+ return detail::IListRefTagImpl<IListRefTag::TAG, T>::unwrap(*this); \
613
+ }
614
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_CASTING)
615
+ #undef DEFINE_CASTING
616
+
617
+ private:
618
+ union Payload {
619
+ const boxed_type* boxed;
620
+ unboxed_type unboxed;
621
+ const materialized_type* materialized;
622
+ Payload() : boxed(nullptr) {}
623
+ };
624
+
625
+ Payload payload_;
626
+ IListRefTag tag_;
627
+ };
628
+
629
+ } // namespace c10
630
+
631
+ #include <ATen/core/IListRef_inl.h>
.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef_inl.h ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/List.h>
4
+ #include <ATen/core/Tensor.h>
5
+
6
+ namespace at {
7
+ class Tensor;
8
+ class OptionalTensorRef;
9
+ }
10
+
11
+
12
+ namespace c10::detail {
13
+
14
+ /*
15
+ * Specializations of `IListRefTagImplBase` that implement the default
16
+ * implementation for `IListRefTag::Unboxed`.
17
+ */
18
+ template <typename T, typename ListElemT>
19
+ class IListRefTagImplBase<IListRefTag::Unboxed, T, ListElemT> {
20
+ public:
21
+ using elem_type = ListElemT;
22
+ using list_type = ArrayRef<elem_type>;
23
+
24
+ /*
25
+ * These `unwrap` static methods unwrap the inner containers out
26
+ * of `IListRef<T>` (and `IListRefIterator<T>`). They are required when
27
+ * the macro `TORCH_ILISTREF_UNWRAP` is called.
28
+ */
29
+ static const list_type& unwrap(const IListRef<T>& ilist) {
30
+ return ilist.payload_.unboxed;
31
+ }
32
+
33
+ static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
34
+ return it.payload_.unboxed_iterator;
35
+ }
36
+
37
+ static const typename list_type::const_iterator& unwrap(
38
+ const IListRefIterator<T>& it) {
39
+ return it.payload_.unboxed_iterator;
40
+ }
41
+
42
+ /*
43
+ * We have these functions (besides the `unwrap`s above) because the
44
+ * implementation for both `IListRef::operator[]` and `IListRefIterator::operator*`
45
+ * weren't syntactically equal for the existing tags at the time
46
+ * (`Unboxed` and `Boxed`).
47
+ */
48
+ static IListRefConstRef<T> front(const list_type& lst) {
49
+ return lst.front();
50
+ }
51
+
52
+ static IListRefConstRef<T> iterator_get(
53
+ const typename list_type::const_iterator& it) {
54
+ return *it;
55
+ }
56
+ };
57
+
58
+ /*
59
+ * Specializations of `IListRefTagImplBase` that implement the default
60
+ * implementation for `IListRefTag::Boxed`.
61
+ */
62
+ template <typename T, typename ListElemT>
63
+ class IListRefTagImplBase<IListRefTag::Boxed, T, ListElemT> {
64
+ public:
65
+ using elem_type = ListElemT;
66
+ using list_type = List<elem_type>;
67
+
68
+ static const list_type& unwrap(const IListRef<T>& ilist) {
69
+ return *ilist.payload_.boxed;
70
+ }
71
+
72
+ static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
73
+ return it.payload_.boxed_iterator;
74
+ }
75
+
76
+ static const typename list_type::const_iterator& unwrap(
77
+ const IListRefIterator<T>& it) {
78
+ return it.payload_.boxed_iterator;
79
+ }
80
+
81
+ static IListRefConstRef<T> front(const list_type& lst) {
82
+ return lst[0];
83
+ }
84
+
85
+ static IListRefConstRef<T> iterator_get(
86
+ const typename list_type::const_iterator& it) {
87
+ return (*it).get().toTensor();
88
+ }
89
+ };
90
+
91
+ /*
92
+ * Specializations of `IListRefTagImplBase` that implement the default
93
+ * implementation for `IListRefTag::Materialized`.
94
+ */
95
+ template <typename T>
96
+ class IListRefTagImplBase<IListRefTag::Materialized, T, MaterializedIListRefElem<T>> {
97
+ public:
98
+ using elem_type = MaterializedIListRefElem<T>;
99
+ using list_type = MaterializedIListRef<T>;
100
+
101
+ static const list_type& unwrap(const IListRef<T>& ilist) {
102
+ return *ilist.payload_.materialized;
103
+ }
104
+
105
+ static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
106
+ return it.payload_.materialized_iterator;
107
+ }
108
+
109
+ static const typename list_type::const_iterator& unwrap(
110
+ const IListRefIterator<T>& it) {
111
+ return it.payload_.materialized_iterator;
112
+ }
113
+
114
+ static IListRefConstRef<T> front(const list_type& lst) {
115
+ return lst[0];
116
+ }
117
+
118
+ static IListRefConstRef<T> iterator_get(
119
+ const typename list_type::const_iterator& it) {
120
+ return *it;
121
+ }
122
+ };
123
+
124
+ /*
125
+ * [Note: ITensorListRef]
126
+ * Specializations necessary for `IListRef<at::Tensor>` type.
127
+ *
128
+ * Since the default implementations are usually done with supporting
129
+ * `Tensor` in mind, we only have to inherit from the base implementations.
130
+ */
131
+ template <>
132
+ class IListRefTagImpl<IListRefTag::Unboxed, at::Tensor>
133
+ : public IListRefTagImplBase<IListRefTag::Unboxed, at::Tensor> {};
134
+
135
+ template <>
136
+ class IListRefTagImpl<IListRefTag::Boxed, at::Tensor>
137
+ : public IListRefTagImplBase<IListRefTag::Boxed, at::Tensor> {};
138
+
139
+ template <>
140
+ class IListRefTagImpl<IListRefTag::Materialized, at::Tensor>
141
+ : public IListRefTagImplBase<
142
+ IListRefTag::Materialized,
143
+ at::Tensor,
144
+ MaterializedIListRefElem<at::Tensor>> {};
145
+
146
+ /*
147
+ * [Note: IOptTensorListRef]
148
+ * Specializations necessary for `IListRef<at::OptionalTensorRef>` type.
149
+ *
150
+ * We can't get an `at::OptionalTensorRef` directly from an instance of
151
+ * `List<optional<Tensor>>` (the type that corresponds to the boxed world).
152
+ *
153
+ * So, the default implementation won't help us. Thus, we have to implement
154
+ * this method ourselves.
155
+ */
156
+ template <>
157
+ class IListRefTagImpl<IListRefTag::Unboxed, at::OptionalTensorRef>
158
+ : public IListRefTagImplBase<IListRefTag::Unboxed, at::OptionalTensorRef> {};
159
+
160
+ template <>
161
+ class IListRefTagImpl<IListRefTag::Boxed, at::OptionalTensorRef>
162
+ : public IListRefTagImplBase<IListRefTag::Boxed, at::OptionalTensorRef, std::optional<at::Tensor>> {
163
+
164
+ public:
165
+ /*
166
+ * Given an instance of the types corresponding to the `Boxed` tag, we override
167
+ * the default implementation, so that we can return a `at::OptionalTensorRef`.
168
+ */
169
+ static IListRefConstRef<at::OptionalTensorRef> iterator_get(
170
+ const typename list_type::const_iterator& it) {
171
+ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdangling-reference")
172
+ const auto& ivalue = (*it).get();
173
+ C10_DIAGNOSTIC_POP()
174
+ if (!ivalue.isNone()) {
175
+ const auto& tensor = ivalue.toTensor();
176
+ return (tensor.defined()) ? tensor : at::OptionalTensorRef{};
177
+ }
178
+ return {};
179
+ }
180
+ };
181
+
182
+ template <>
183
+ class IListRefTagImpl<IListRefTag::Materialized, at::OptionalTensorRef>
184
+ : public IListRefTagImplBase<
185
+ IListRefTag::Materialized,
186
+ at::OptionalTensorRef,
187
+ MaterializedIListRefElem<at::OptionalTensorRef>> {};
188
+
189
+ } // namespace c10::detail
190
+
191
+
192
+ namespace at {
193
+
194
+ // [Note: ITensorListRef]
195
+ using ITensorListRef = c10::IListRef<at::Tensor>;
196
+ using ITensorListRefIterator = c10::IListRefIterator<at::Tensor>;
197
+ using MaterializedITensorListRef = c10::detail::MaterializedIListRef<at::Tensor>;
198
+ // [Note: IOptTensorListRef]
199
+ using IOptTensorListRef = c10::IListRef<at::OptionalTensorRef>;
200
+ using IOptTensorListRefIterator = c10::IListRefIterator<at::OptionalTensorRef>;
201
+ using MaterializedIOptTensorListRef = c10::detail::MaterializedIListRef<at::OptionalTensorRef>;
202
+
203
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // The legacy mechanism for dispatching operators in ATen is a Type
4
+ // object, which is essentially a giant virtual dispatch table
5
+ // for every operation we support dynamically dispatching over.
6
+ //
7
+ // This has been deprecated in favor of ATenDispatch, and in the future,
8
+ // c10 dispatcher.
9
+ // TODO: Clean up what remains here
10
+
11
+ #include <c10/core/impl/LocalDispatchKeySet.h>
12
+
13
+ namespace at {
14
+
15
+ // A RAII, thread local (!) guard that will disable dispatch to variable
16
+ // handler.
17
+ //
18
+ // NOTE [ Treating Variables as non-Variables in type dispatch ]
19
+ //
20
+ // What exactly does AutoDispatchBelowAutograd do? The short answer is, it causes
21
+ // dispatches on ATen functions to go to the non-variable implementation,
22
+ // bypassing autograd handling (and also profiling and tracing).
23
+ //
24
+ // To understand why this guard exists, it's helpful to understand the history
25
+ // behind how Variable was implemented. Previously, Variables were implemented
26
+ // as a wrapper on Tensors; so the act of processing a Variable involved
27
+ // unwrapping the underlying Tensor, and then calling the underlying base
28
+ // operation on /that/ Tensor.
29
+ //
30
+ // However, after the Variable/Tensor merge, there is no concept of unwrapping
31
+ // a tensor anymore. If you just call the operation on the same variable
32
+ // again inside your VariableType handler, you'll dispatch back to
33
+ // VariableType, which is not what we want.
34
+ //
35
+ // The solution to the above problem is to add `at::AutoDispatchBelowAutograd`, which
36
+ // when enabled will cause `legacyTensorType()` and `getType()` to always return
37
+ // non-Variable type, even if the tensor being called on is a variable.
38
+
39
+ /* Note [AutoDispatchBelowAutograd]
40
+ * AutoDispatchBelowAutograd is **INTERNAL ONLY**: it should be used only
41
+ * for kernel implementations and customized C++ kernels.
42
+ * If you are looking for a guard to run workload in inference mode, please use
43
+ * c10::InferenceMode RAII which is user facing API.
44
+ * In the past AutoDispatchBelowAutograd(or its old version AutoNonVariableTypeMode)
45
+ * was used in the user code for inference-only workload, this was under risk of
46
+ * producing wrong results silently in some edge cases. For example:
47
+ * ```
48
+ * torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
49
+ * torch::Tensor out = s * s;
50
+ * {
51
+ * at::AutoDispatchBelowAutograd guard;
52
+ * s.add_(1); // Skips version bump on `s`.
53
+ * }
54
+ * // WRONG GRADIENT! s.grad() are now computed using `s` value after the
55
+ * // inplace update.
56
+ * out.backward(torch::ones_like(out));
57
+ * ```
58
+ * Users should use `c10::InferenceMode` here so that it'll properly throw an
59
+ * error saying "one of the variables needed for gradient computation has been modified."
60
+ */
61
+ struct TORCH_API AutoDispatchBelowAutograd {
62
+ AutoDispatchBelowAutograd() :
63
+ autograd_guard_(c10::autograd_dispatch_keyset) {
64
+ }
65
+
66
+ // disable all autograd dispatch keys
67
+ c10::impl::ExcludeDispatchKeyGuard autograd_guard_;
68
+ };
69
+
70
+ // TODO: AutoNonVariableTypeMode should be removed in release 1.10.
71
+ struct TORCH_API AutoNonVariableTypeMode {
72
+ AutoNonVariableTypeMode(bool enabled = true) :
73
+ autograd_guard_(c10::autograd_dispatch_keyset) {
74
+ TORCH_WARN_ONCE("AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. "
75
+ "For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, "
76
+ "If you are looking for a user facing API to enable running your inference-only "
77
+ "workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code "
78
+ "is under risk of producing silent wrong result in some edge cases. "
79
+ "See Note [AutoDispatchBelowAutograd] for more details.");
80
+ TORCH_INTERNAL_ASSERT(enabled);
81
+ }
82
+
83
+ // disable all autograd dispatch keys
84
+ c10::impl::ExcludeDispatchKeyGuard autograd_guard_;
85
+ };
86
+
87
+ struct TORCH_API AutoDispatchSkipFunctionalize {
88
+ AutoDispatchSkipFunctionalize() :
89
+ dispatch_key_guard_(c10::DispatchKeySet(c10::DispatchKey::Functionalize)) {
90
+ }
91
+ c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_;
92
+ };
93
+
94
+ /* Note [AutoDispatchBelowADInplaceOrView]
95
+ * AutoDispatchBelowADInplaceOrView is equivalent to AutoNonVariableTypeMode
96
+ * before we split inplace & view ops out of VariableType kernel.
97
+ * Note this guard is used in VariableType kernels for functional ops
98
+ * as well as ADInplaceOrView kernels for inplace/view ops to enforce the
99
+ * Invariant:
100
+ * Once you are in VariableType/ADInplaceOrView kernel for an op,
101
+ * you never go back to a kernel on same dispatch key until
102
+ * you finish the current op.
103
+ */
104
+ struct TORCH_API AutoDispatchBelowADInplaceOrView {
105
+ AutoDispatchBelowADInplaceOrView() :
106
+ dispatch_key_guard_(c10::autograd_dispatch_keyset_with_ADInplaceOrView) {
107
+ }
108
+ // disable Autograd & ADInplaceOrView dispatch keys
109
+ c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_;
110
+ };
111
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/List.h ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/ivalue_to.h>
4
+ #include <ATen/core/jit_type_base.h>
5
+ #include <c10/macros/Macros.h>
6
+ #include <c10/macros/Export.h>
7
+ #include <c10/util/TypeTraits.h>
8
+ #include <c10/util/TypeList.h>
9
+ #include <c10/util/intrusive_ptr.h>
10
+ #include <c10/util/ArrayRef.h>
11
+ #include <optional>
12
+ #include <vector>
13
+
14
+ namespace at {
15
+ class Tensor;
16
+ }
17
+ namespace c10 {
18
+ struct IValue;
19
+ template<class T> class List;
20
+ struct Type;
21
+
22
+ namespace detail {
23
+
24
+ struct ListImpl final : public c10::intrusive_ptr_target {
25
+ using list_type = std::vector<IValue>;
26
+
27
+ explicit TORCH_API ListImpl(list_type list_, TypePtr elementType_);
28
+
29
+ list_type list;
30
+
31
+ TypePtr elementType;
32
+
33
+ intrusive_ptr<ListImpl> copy() const {
34
+ return make_intrusive<ListImpl>(list, elementType);
35
+ }
36
+ friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs);
37
+ };
38
+ }
39
+
40
+ namespace impl {
41
+
42
+ template<class T, class Iterator> class ListIterator;
43
+
44
+ template<class T, class Iterator> class ListElementReference;
45
+
46
+ template<class T, class Iterator>
47
+ void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept;
48
+
49
+ template<class T, class Iterator>
50
+ bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs);
51
+
52
+ template<class T, class Iterator>
53
+ bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs);
54
+
55
+ template<class T>
56
+ struct ListElementConstReferenceTraits {
57
+ // In the general case, we use IValue::to().
58
+ using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return<T>::type;
59
+ };
60
+
61
+ // There is no to() overload for std::optional<std::string>.
62
+ template<>
63
+ struct ListElementConstReferenceTraits<std::optional<std::string>> {
64
+ using const_reference = std::optional<std::reference_wrapper<const std::string>>;
65
+ };
66
+
67
+ template<class T, class Iterator>
68
+ class ListElementReference final {
69
+ public:
70
+ operator std::conditional_t<
71
+ std::is_reference_v<typename c10::detail::
72
+ ivalue_to_const_ref_overload_return<T>::type>,
73
+ const T&,
74
+ T>() const;
75
+
76
+ ListElementReference& operator=(T&& new_value) &&;
77
+
78
+ ListElementReference& operator=(const T& new_value) &&;
79
+
80
+ // assigning another ref to this assigns the underlying value
81
+ ListElementReference& operator=(ListElementReference&& rhs) && noexcept;
82
+
83
+ const IValue& get() const& {
84
+ return *iterator_;
85
+ }
86
+
87
+ friend void swap<T, Iterator>(ListElementReference&& lhs, ListElementReference&& rhs) noexcept;
88
+
89
+ ListElementReference(const ListElementReference&) = delete;
90
+ ListElementReference& operator=(const ListElementReference&) = delete;
91
+ ~ListElementReference() = default;
92
+
93
+ private:
94
+ ListElementReference(Iterator iter)
95
+ : iterator_(iter) {}
96
+
97
+ // allow moving, but only our friends (i.e. the List class) can move us
98
+ ListElementReference(ListElementReference&&) noexcept = default;
99
+ ListElementReference& operator=(ListElementReference&& rhs) & noexcept {
100
+ iterator_ = std::move(rhs.iterator_);
101
+ return *this;
102
+ }
103
+
104
+ friend class List<T>;
105
+ friend class ListIterator<T, Iterator>;
106
+
107
+ Iterator iterator_;
108
+ };
109
+
110
+ // this wraps vector::iterator to make sure user code can't rely
111
+ // on it being the type of the underlying vector.
112
+ template <class T, class Iterator>
113
+ class ListIterator final {
114
+ public:
115
+ // C++17 friendly std::iterator implementation
116
+ using iterator_category = std::random_access_iterator_tag;
117
+ using value_type = T;
118
+ using difference_type = std::ptrdiff_t;
119
+ using pointer = T*;
120
+ using reference = ListElementReference<T, Iterator>;
121
+
122
+ explicit ListIterator() = default;
123
+ ~ListIterator() = default;
124
+
125
+ ListIterator(const ListIterator&) = default;
126
+ ListIterator(ListIterator&&) noexcept = default;
127
+ ListIterator& operator=(const ListIterator&) = default;
128
+ ListIterator& operator=(ListIterator&&) noexcept = default;
129
+
130
+ ListIterator& operator++() {
131
+ ++iterator_;
132
+ return *this;
133
+ }
134
+
135
+ ListIterator operator++(int) {
136
+ ListIterator copy(*this);
137
+ ++*this;
138
+ return copy;
139
+ }
140
+
141
+ ListIterator& operator--() {
142
+ --iterator_;
143
+ return *this;
144
+ }
145
+
146
+ ListIterator operator--(int) {
147
+ ListIterator copy(*this);
148
+ --*this;
149
+ return copy;
150
+ }
151
+
152
+ ListIterator& operator+=(typename List<T>::size_type offset) {
153
+ iterator_ += offset;
154
+ return *this;
155
+ }
156
+
157
+ ListIterator& operator-=(typename List<T>::size_type offset) {
158
+ iterator_ -= offset;
159
+ return *this;
160
+ }
161
+
162
+ ListIterator operator+(typename List<T>::size_type offset) const {
163
+ return ListIterator{iterator_ + offset};
164
+ }
165
+
166
+ ListIterator operator-(typename List<T>::size_type offset) const {
167
+ return ListIterator{iterator_ - offset};
168
+ }
169
+
170
+ friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) {
171
+ return lhs.iterator_ - rhs.iterator_;
172
+ }
173
+
174
+ ListElementReference<T, Iterator> operator*() const {
175
+ return {iterator_};
176
+ }
177
+
178
+ ListElementReference<T, Iterator> operator[](typename List<T>::size_type offset) const {
179
+ return {iterator_ + offset};
180
+ }
181
+
182
+ private:
183
+ explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {}
184
+
185
+ Iterator iterator_;
186
+
187
+ friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) {
188
+ return lhs.iterator_ == rhs.iterator_;
189
+ }
190
+
191
+ friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) {
192
+ return !(lhs == rhs);
193
+ }
194
+
195
+ friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) {
196
+ return lhs.iterator_ < rhs.iterator_;
197
+ }
198
+
199
+ friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) {
200
+ return lhs.iterator_ <= rhs.iterator_;
201
+ }
202
+
203
+ friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) {
204
+ return lhs.iterator_ > rhs.iterator_;
205
+ }
206
+
207
+ friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) {
208
+ return lhs.iterator_ >= rhs.iterator_;
209
+ }
210
+
211
+ friend class ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
212
+ friend class List<T>;
213
+ };
214
+
215
+ template<class T> List<T> toTypedList(List<IValue> list);
216
+ template<class T> List<IValue> toList(List<T>&& list);
217
+ template<class T> List<IValue> toList(const List<T>& list);
218
+ const IValue* ptr_to_first_element(const List<IValue>& list);
219
+ }
220
+
221
+ /**
222
+ * An object of this class stores a list of values of type T.
223
+ *
224
+ * This is a pointer type. After a copy, both Lists
225
+ * will share the same storage:
226
+ *
227
+ * > List<int> a;
228
+ * > List<int> b = a;
229
+ * > b.push_back("three");
230
+ * > ASSERT("three" == a.get(0));
231
+ *
232
+ * We use this class in the PyTorch kernel API instead of
233
+ * std::vector<T>, because that allows us to do optimizations
234
+ * and switch out the underlying list implementation without
235
+ * breaking backwards compatibility for the kernel API.
236
+ */
237
+ template<class T>
238
+ // NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
239
+ class List final {
240
+ private:
241
+ // This is an intrusive_ptr because List is a pointer type.
242
+ // Invariant: This will never be a nullptr, there will always be a valid
243
+ // ListImpl.
244
+ c10::intrusive_ptr<c10::detail::ListImpl> impl_;
245
+
246
+ using internal_reference_type = impl::ListElementReference<T, typename c10::detail::ListImpl::list_type::iterator>;
247
+ using internal_const_reference_type = typename impl::ListElementConstReferenceTraits<T>::const_reference;
248
+
249
+ public:
250
+ using value_type = T;
251
+ using size_type = typename c10::detail::ListImpl::list_type::size_type;
252
+ using iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
253
+ using const_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
254
+ using reverse_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::reverse_iterator>;
255
+
256
+ /**
257
+ * Constructs an empty list.
258
+ */
259
+ explicit List();
260
+
261
+ /**
262
+ * Constructs a list with some initial values.
263
+ * Example:
264
+ * List<int> a({2, 3, 4});
265
+ */
266
+ List(std::initializer_list<T> initial_values);
267
+ explicit List(ArrayRef<T> initial_values);
268
+
269
+ /**
270
+ * Create a generic list with runtime type information.
271
+ * This only works for c10::impl::GenericList and is not part of the public API
272
+ * but only supposed to be used internally by PyTorch.
273
+ */
274
+ explicit List(TypePtr elementType);
275
+
276
+ List(const List&) = default;
277
+ List& operator=(const List&) = default;
278
+ ~List() = default;
279
+
280
+ /**
281
+ * Create a new List pointing to a deep copy of the same data.
282
+ * The List returned is a new list with separate storage.
283
+ * Changes in it are not reflected in the original list or vice versa.
284
+ */
285
+ List copy() const;
286
+
287
+ /**
288
+ * Returns the element at specified location pos, with bounds checking.
289
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
290
+ */
291
+ internal_const_reference_type get(size_type pos) const;
292
+
293
+ /**
294
+ * Moves out the element at the specified location pos and returns it, with bounds checking.
295
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
296
+ * The list contains an invalid element at position pos afterwards. Any operations
297
+ * on it before re-setting it are invalid.
298
+ */
299
+ value_type extract(size_type pos) const;
300
+
301
+ /**
302
+ * Returns a reference to the element at specified location pos, with bounds checking.
303
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
304
+ *
305
+ * You cannot store the reference, but you can read it and assign new values to it:
306
+ *
307
+ * List<int64_t> list = ...;
308
+ * list[2] = 5;
309
+ * int64_t v = list[1];
310
+ */
311
+ internal_const_reference_type operator[](size_type pos) const;
312
+
313
+ internal_reference_type operator[](size_type pos);
314
+
315
+ /**
316
+ * Assigns a new value to the element at location pos.
317
+ */
318
+ void set(size_type pos, const value_type& value) const;
319
+
320
+ /**
321
+ * Assigns a new value to the element at location pos.
322
+ */
323
+ void set(size_type pos, value_type&& value) const;
324
+
325
+ /**
326
+ * Returns an iterator to the first element of the container.
327
+ * If the container is empty, the returned iterator will be equal to end().
328
+ */
329
+ iterator begin() const;
330
+
331
+ /**
332
+ * Returns an iterator to the element following the last element of the container.
333
+ * This element acts as a placeholder; attempting to access it results in undefined behavior.
334
+ */
335
+ iterator end() const;
336
+
337
+ /**
338
+ * Checks if the container has no elements.
339
+ */
340
+ bool empty() const;
341
+
342
+ /**
343
+ * Returns the number of elements in the container
344
+ */
345
+ size_type size() const;
346
+
347
+ /**
348
+ * Increase the capacity of the vector to a value that's greater or equal to new_cap.
349
+ */
350
+ void reserve(size_type new_cap) const;
351
+
352
+ /**
353
+ * Erases all elements from the container. After this call, size() returns zero.
354
+ * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated.
355
+ */
356
+ void clear() const;
357
+
358
+ /**
359
+ * Inserts value before pos.
360
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
361
+ */
362
+ iterator insert(iterator pos, const T& value) const;
363
+
364
+ /**
365
+ * Inserts value before pos.
366
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
367
+ */
368
+ iterator insert(iterator pos, T&& value) const;
369
+
370
+ /**
371
+ * Inserts a new element into the container directly before pos.
372
+ * The new element is constructed with the given arguments.
373
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
374
+ */
375
+ template<class... Args>
376
+ iterator emplace(iterator pos, Args&&... value) const;
377
+
378
+ /**
379
+ * Appends the given element value to the end of the container.
380
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
381
+ */
382
+ void push_back(const T& value) const;
383
+
384
+ /**
385
+ * Appends the given element value to the end of the container.
386
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
387
+ */
388
+ void push_back(T&& value) const;
389
+
390
+ /**
391
+ * Appends the given list to the end of the container. Uses at most one memory allocation.
392
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
393
+ */
394
+ void append(List<T> lst) const;
395
+
396
+ /**
397
+ * Appends the given element value to the end of the container.
398
+ * The new element is constructed with the given arguments.
399
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
400
+ */
401
+ template<class... Args>
402
+ void emplace_back(Args&&... args) const;
403
+
404
+ /**
405
+ * Removes the element at pos.
406
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
407
+ */
408
+ iterator erase(iterator pos) const;
409
+
410
+ /**
411
+ * Removes the elements in the range [first, last).
412
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
413
+ */
414
+ iterator erase(iterator first, iterator last) const;
415
+
416
+ /**
417
+ * Removes the last element of the container.
418
+ * Calling pop_back on an empty container is undefined.
419
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
420
+ */
421
+ void pop_back() const;
422
+
423
+ /**
424
+ * Resizes the container to contain count elements.
425
+ * If the current size is less than count, additional default-inserted elements are appended.
426
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
427
+ */
428
+ void resize(size_type count) const;
429
+
430
+ /**
431
+ * Resizes the container to contain count elements.
432
+ * If the current size is less than count, additional copies of value are appended.
433
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
434
+ */
435
+ void resize(size_type count, const T& value) const;
436
+
437
+ /**
438
+ * Value equality comparison. This function implements Python-like semantics for
439
+ * equality: two lists with the same identity (e.g. same pointer) trivially
440
+ * compare equal, otherwise each element is compared for equality.
441
+ */
442
+ template <class T_>
443
+ friend bool operator==(const List<T_>& lhs, const List<T_>& rhs);
444
+
445
+ template <class T_>
446
+ friend bool operator!=(const List<T_>& lhs, const List<T_>& rhs);
447
+
448
+ /**
449
+ * Identity comparison. Returns true if and only if `rhs` represents the same
450
+ * List object as `this`.
451
+ */
452
+ bool is(const List<T>& rhs) const;
453
+
454
+ std::vector<T> vec() const;
455
+
456
+ /**
457
+ * Returns the number of Lists currently pointing to this same list.
458
+ * If this is the only instance pointing to this list, returns 1.
459
+ */
460
+ // TODO Test use_count
461
+ size_t use_count() const;
462
+
463
+ TypePtr elementType() const;
464
+
465
+ // See [unsafe set type] for why this exists.
466
+ void unsafeSetElementType(TypePtr t);
467
+
468
+ private:
469
+ explicit List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements);
470
+ explicit List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements);
471
+ friend struct IValue;
472
+ template<class T_> friend List<T_> impl::toTypedList(List<IValue>);
473
+ template<class T_> friend List<IValue> impl::toList(List<T_>&&);
474
+ template<class T_> friend List<IValue> impl::toList(const List<T_>&);
475
+ friend const IValue* impl::ptr_to_first_element(const List<IValue>& list);
476
+ };
477
+
478
+ namespace impl {
479
+ // GenericList is how IValue stores lists. It is, however, not part of the
480
+ // public API. Kernels should use Lists with concrete types instead
481
+ // (maybe except for some internal prim ops).
482
+ using GenericList = List<IValue>;
483
+
484
+ }
485
+ }
486
+
487
+ namespace torch {
488
+ template<class T> using List = c10::List<T>;
489
+ }
490
+
491
+ #include <ATen/core/List_inl.h> // IWYU pragma: keep
.venv/lib/python3.12/site-packages/torch/include/ATen/core/List_inl.h ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/jit_type_base.h>
4
+ #include <ATen/core/ivalue.h>
5
+
6
+ namespace c10 {
7
+
8
+ template<class T> decltype(auto) getTypePtr();
9
+ std::string toString(const Type& type);
10
+
11
+ template<class T>
12
+ List<T>::List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements)
13
+ : impl_(std::move(elements)) {}
14
+
15
+ template<class T>
16
+ List<T>::List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements)
17
+ : impl_(elements) {}
18
+
19
+ template<class T>
20
+ List<T>::List()
21
+ : List(make_intrusive<c10::detail::ListImpl>(
22
+ typename c10::detail::ListImpl::list_type(),
23
+ getTypePtr<T>())) {
24
+ static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType) instead.");
25
+ }
26
+
27
+ template<class T>
28
+ List<T>::List(ArrayRef<T> values)
29
+ : List(make_intrusive<c10::detail::ListImpl>(
30
+ typename c10::detail::ListImpl::list_type(),
31
+ getTypePtr<T>())) {
32
+ static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
33
+ impl_->list.reserve(values.size());
34
+ for (const T& element : values) {
35
+ impl_->list.push_back(element);
36
+ }
37
+ }
38
+
39
+ template<class T>
40
+ List<T>::List(std::initializer_list<T> initial_values)
41
+ : List(ArrayRef<T>(initial_values)) {
42
+ static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
43
+ }
44
+
45
+ template<class T>
46
+ List<T>::List(TypePtr elementType)
47
+ : List(make_intrusive<c10::detail::ListImpl>(
48
+ typename c10::detail::ListImpl::list_type(),
49
+ std::move(elementType))) {
50
+ static_assert(std::is_same_v<T, IValue> || std::is_same_v<T, c10::intrusive_ptr<ivalue::Future>>,
51
+ "This constructor is only valid for c10::impl::GenericList or List<Future>.");
52
+ }
53
+
54
+ namespace impl {
55
+ template<class T>
56
+ List<T> toTypedList(impl::GenericList list) {
57
+ // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
58
+ // because upcasting would allow people to add types into the new list that would break the old list.
59
+ // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
60
+ // allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
61
+ // without having to copy it. This is also used to provide backwards compatibility with some old models
62
+ // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
63
+ // as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
64
+ // have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
65
+ TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
66
+ || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr<T>()))
67
+ , "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(*getTypePtr<T>()), ">. Types mismatch.");
68
+ return List<T>(std::move(list.impl_));
69
+ }
70
+
71
+ template<class T>
72
+ impl::GenericList toList(List<T>&& list) {
73
+ return GenericList(std::move(list.impl_));
74
+ }
75
+ template<class T>
76
+ impl::GenericList toList(const List<T>& list) {
77
+ return GenericList(list.impl_);
78
+ }
79
+ }
80
+
81
+ template<class T>
82
+ List<T> List<T>::copy() const {
83
+ return List<T>(impl_->copy());
84
+ }
85
+
86
+ namespace detail {
87
// Identity overload: the element already has type T and is returned as is.
template <class T>
T list_element_to(T element) {
  return element;
}
91
+ template<class T>
92
+ T list_element_to(const IValue& element) {
93
+ return element.template to<T>();
94
+ }
95
+ template<class T>
96
+ T list_element_to(IValue&& element) {
97
+ return std::move(element).template to<T>();
98
+ }
99
+ template<class T>
100
+ struct ListElementFrom {
101
+ static IValue from(const T& element) {
102
+ return element;
103
+ }
104
+ static IValue from(T&& element) {
105
+ return std::move(element);
106
+ }
107
+ };
108
+ template<>
109
+ struct ListElementFrom<IValue> {
110
+ static const IValue& from(const IValue& element) {
111
+ return element;
112
+ }
113
+ static IValue&& from(IValue&& element) {
114
+ return std::move(element);
115
+ }
116
+ };
117
+ }
118
+
119
+ namespace impl {
120
+
121
+ template <class T, class Iterator>
122
+ ListElementReference<T, Iterator>::operator std::conditional_t<
123
+ std::is_reference_v<typename c10::detail::ivalue_to_const_ref_overload_return<
124
+ T>::type>,
125
+ const T&,
126
+ T>() const {
127
+ return iterator_->template to<T>();
128
+ }
129
+
130
+ template<class T, class Iterator>
131
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(T&& new_value) && {
132
+ *iterator_ = c10::detail::ListElementFrom<T>::from(std::move(new_value));
133
+ return *this;
134
+ }
135
+
136
+ template<class T, class Iterator>
137
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(const T& new_value) && {
138
+ *iterator_ = c10::detail::ListElementFrom<T>::from(new_value);
139
+ return *this;
140
+ }
141
+
142
+ template<class T, class Iterator>
143
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(ListElementReference<T, Iterator>&& rhs) && noexcept {
144
+ *iterator_ = *rhs.iterator_;
145
+ return *this;
146
+ }
147
+
148
+ template<class T, class Iterator>
149
+ void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept {
150
+ std::swap(*lhs.iterator_, *rhs.iterator_);
151
+ }
152
+
153
+ template<class T, class Iterator>
154
+ bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs) {
155
+ const T& lhs_tmp = lhs;
156
+ return lhs_tmp == rhs;
157
+ }
158
+
159
+ template<class T, class Iterator>
160
+ inline bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs) {
161
+ return rhs == lhs;
162
+ }
163
+
164
+ template<class T>
165
+ inline typename ListElementConstReferenceTraits<T>::const_reference
166
+ list_element_to_const_ref(const IValue& element) {
167
+ return element.template to<T>();
168
+ }
169
+
170
+ template<>
171
+ inline typename ListElementConstReferenceTraits<std::optional<std::string>>::const_reference
172
+ list_element_to_const_ref<std::optional<std::string>>(const IValue& element) {
173
+ return element.toOptionalStringRef();
174
+ }
175
+
176
+ } // namespace impl
177
+
178
+ template<class T>
179
+ void List<T>::set(size_type pos, const value_type& value) const {
180
+ impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(value);
181
+ }
182
+
183
+ template<class T>
184
+ void List<T>::set(size_type pos, value_type&& value) const {
185
+ impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(std::move(value));
186
+ }
187
+
188
+ template<class T>
189
+ typename List<T>::internal_const_reference_type List<T>::get(size_type pos) const {
190
+ return operator[](pos);
191
+ }
192
+
193
+ template<class T>
194
+ typename List<T>::internal_const_reference_type List<T>::operator[](size_type pos) const {
195
+ return c10::impl::list_element_to_const_ref<T>(impl_->list.at(pos));
196
+ }
197
+
198
+ template<class T>
199
+ typename List<T>::internal_reference_type List<T>::operator[](size_type pos) {
200
+ static_cast<void>(impl_->list.at(pos)); // Throw the exception if it is out of range.
201
+ return {impl_->list.begin() + static_cast<typename decltype(impl_->list)::difference_type>(pos)};
202
+ }
203
+
204
+ template<class T>
205
+ typename List<T>::value_type List<T>::extract(size_type pos) const {
206
+ auto& elem = impl_->list.at(pos);
207
+ auto result = c10::detail::list_element_to<T>(std::move(elem));
208
+ // Reset the list element to a T() instead of None to keep it correctly typed
209
+ elem = c10::detail::ListElementFrom<T>::from(T{});
210
+ return result;
211
+ }
212
+
213
+ template<class T>
214
+ typename List<T>::iterator List<T>::begin() const {
215
+ return iterator(impl_->list.begin());
216
+ }
217
+
218
+ template<class T>
219
+ typename List<T>::iterator List<T>::end() const {
220
+ return iterator(impl_->list.end());
221
+ }
222
+
223
+ template<class T>
224
+ bool List<T>::empty() const {
225
+ return impl_->list.empty();
226
+ }
227
+
228
+ template<class T>
229
+ typename List<T>::size_type List<T>::size() const {
230
+ return impl_->list.size();
231
+ }
232
+
233
+ template<class T>
234
+ void List<T>::reserve(size_type new_cap) const {
235
+ impl_->list.reserve(new_cap);
236
+ }
237
+
238
+ template<class T>
239
+ void List<T>::clear() const {
240
+ impl_->list.clear();
241
+ }
242
+
243
+ template<class T>
244
+ typename List<T>::iterator List<T>::insert(iterator pos, const T& value) const {
245
+ return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(value)) };
246
+ }
247
+
248
+ template<class T>
249
+ typename List<T>::iterator List<T>::insert(iterator pos, T&& value) const {
250
+ return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(std::move(value))) };
251
+ }
252
+
253
+ template<class T>
254
+ template<class... Args>
255
+ typename List<T>::iterator List<T>::emplace(iterator pos, Args&&... value) const {
256
+ // TODO Use list_element_from?
257
+ return iterator { impl_->list.emplace(pos.iterator_, std::forward<Args>(value)...) };
258
+ }
259
+
260
+ template<class T>
261
+ void List<T>::push_back(const T& value) const {
262
+ impl_->list.push_back(c10::detail::ListElementFrom<T>::from(value));
263
+ }
264
+
265
+ template<class T>
266
+ void List<T>::push_back(T&& value) const {
267
+ impl_->list.push_back(c10::detail::ListElementFrom<T>::from(std::move(value)));
268
+ }
269
+
270
+ template<class T>
271
+ void List<T>::append(List<T> b) const {
272
+ if (b.use_count() == 1) {
273
+ impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end()));
274
+ } else {
275
+ impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end());
276
+ }
277
+ }
278
+
279
+ template<class T>
280
+ template<class... Args>
281
+ void List<T>::emplace_back(Args&&... args) const {
282
+ // TODO Use list_element_from?
283
+ impl_->list.push_back(T(std::forward<Args>(args)...));
284
+ }
285
+
286
+ template<class T>
287
+ typename List<T>::iterator List<T>::erase(iterator pos) const {
288
+ return iterator { impl_->list.erase(pos.iterator_) };
289
+ }
290
+
291
+ template<class T>
292
+ typename List<T>::iterator List<T>::erase(iterator first, iterator last) const {
293
+ return iterator { impl_->list.erase(first.iterator_, last.iterator_) };
294
+ }
295
+
296
+ template<class T>
297
+ void List<T>::pop_back() const {
298
+ impl_->list.pop_back();
299
+ }
300
+
301
+ template<class T>
302
+ void List<T>::resize(size_type count) const {
303
+ impl_->list.resize(count, T{});
304
+ }
305
+
306
+ template<class T>
307
+ void List<T>::resize(size_type count, const T& value) const {
308
+ impl_->list.resize(count, value);
309
+ }
310
+
311
+ template<class T>
312
+ bool operator==(const List<T>& lhs, const List<T>& rhs) {
313
+ // Lists with the same identity trivially compare equal.
314
+ if (lhs.impl_ == rhs.impl_) {
315
+ return true;
316
+ }
317
+
318
+ // Otherwise, just compare values directly.
319
+ return *lhs.impl_ == *rhs.impl_;
320
+ }
321
+
322
+ template<class T>
323
+ bool operator!=(const List<T>& lhs, const List<T>& rhs) {
324
+ return !(lhs == rhs);
325
+ }
326
+
327
+ template<class T>
328
+ bool List<T>::is(const List<T>& rhs) const {
329
+ return this->impl_ == rhs.impl_;
330
+ }
331
+
332
+ template<class T>
333
+ std::vector<T> List<T>::vec() const {
334
+ std::vector<T> result(begin(), end());
335
+ return result;
336
+ }
337
+
338
+ template<class T>
339
+ size_t List<T>::use_count() const {
340
+ return impl_.use_count();
341
+ }
342
+
343
+ template <class T>
344
+ TypePtr List<T>::elementType() const {
345
+ return impl_->elementType;
346
+ }
347
+
348
+ template <class T>
349
+ void List<T>::unsafeSetElementType(TypePtr t) {
350
+ impl_->elementType = std::move(t);
351
+ }
352
+
353
+ }
.venv/lib/python3.12/site-packages/torch/include/ATen/core/MT19937RNGEngine.h ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/util/irange.h>
4
+
5
+ // define constants like M_PI and C keywords for MSVC
6
+ #ifdef _MSC_VER
7
+ #ifndef _USE_MATH_DEFINES
8
+ #define _USE_MATH_DEFINES
9
+ #endif
10
+ #include <math.h>
11
+ #endif
12
+
13
+ #include <array>
14
+ #include <cmath>
15
+ #include <cstdint>
16
+
17
+ namespace at {
18
+
19
// Number of 32-bit words in the generator state, and the offset used when
// combining state words during regeneration.
constexpr int MERSENNE_STATE_N{624};
constexpr int MERSENNE_STATE_M{397};
// Constant xor-ed into a regenerated word when the low bit of its input is set.
constexpr uint32_t MATRIX_A{0x9908b0dfU};
// Most-significant bit / least-significant 31 bits of a 32-bit state word.
constexpr uint32_t UMASK{0x80000000U};
constexpr uint32_t LMASK{0x7fffffffU};
24
+
25
+ /**
26
+ * Note [Mt19937 Engine implementation]
27
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
28
+ * Originally implemented in:
29
+ * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/CODES/MTARCOK/mt19937ar-cok.c
30
+ * and modified with C++ constructs. Moreover the state array of the engine
31
+ * has been modified to hold 32 bit uints instead of 64 bits.
32
+ *
33
+ * Note that we reimplemented mt19937 instead of using std::mt19937 because,
34
+ * at::mt19937 turns out to be faster in the pytorch codebase. PyTorch builds with -O2
35
+ * by default and following are the benchmark numbers (benchmark code can be found at
36
+ * https://github.com/syed-ahmed/benchmark-rngs):
37
+ *
38
+ * with -O2
39
+ * Time to get 100000000 philox randoms with at::uniform_real_distribution = 0.462759s
40
+ * Time to get 100000000 at::mt19937 randoms with at::uniform_real_distribution = 0.39628s
41
+ * Time to get 100000000 std::mt19937 randoms with std::uniform_real_distribution = 0.352087s
42
+ * Time to get 100000000 std::mt19937 randoms with at::uniform_real_distribution = 0.419454s
43
+ *
44
+ * std::mt19937 is faster when used in conjunction with std::uniform_real_distribution,
45
+ * however we can't use std::uniform_real_distribution because of this bug:
46
+ * http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524. Plus, even if we used
47
+ * std::uniform_real_distribution and filtered out the 1's, it is a different algorithm
48
+ * than what's in pytorch currently and that messes up the tests in tests_distributions.py.
49
+ * The other option, using std::mt19937 with at::uniform_real_distribution is a tad bit slower
50
+ * than at::mt19937 with at::uniform_real_distribution and hence, we went with the latter.
51
+ *
52
+ * Copyright notice:
53
+ * A C-program for MT19937, with initialization improved 2002/2/10.
54
+ * Coded by Takuji Nishimura and Makoto Matsumoto.
55
+ * This is a faster version by taking Shawn Cokus's optimization,
56
+ * Matthe Bellew's simplification, Isaku Wada's real version.
57
+ *
58
+ * Before using, initialize the state by using init_genrand(seed)
59
+ * or init_by_array(init_key, key_length).
60
+ *
61
+ * Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura,
62
+ * All rights reserved.
63
+ *
64
+ * Redistribution and use in source and binary forms, with or without
65
+ * modification, are permitted provided that the following conditions
66
+ * are met:
67
+ *
68
+ * 1. Redistributions of source code must retain the above copyright
69
+ * notice, this list of conditions and the following disclaimer.
70
+ *
71
+ * 2. Redistributions in binary form must reproduce the above copyright
72
+ * notice, this list of conditions and the following disclaimer in the
73
+ * documentation and/or other materials provided with the distribution.
74
+ *
75
+ * 3. The names of its contributors may not be used to endorse or promote
76
+ * products derived from this software without specific prior written
77
+ * permission.
78
+ *
79
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
80
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
81
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
82
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
83
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
84
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
85
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
86
+ * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
87
+ * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
88
+ * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
89
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
90
+ *
91
+ *
92
+ * Any feedback is very welcome.
93
+ * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html
94
+ * email: m-mat @ math.sci.hiroshima-u.ac.jp (remove space)
95
+ */
96
+
97
+ /**
98
+ * mt19937_data_pod is used to get POD data in and out
99
+ * of mt19937_engine. Used in torch.get_rng_state and
100
+ * torch.set_rng_state functions.
101
+ */
102
+ struct mt19937_data_pod {
103
+ uint64_t seed_;
104
+ int left_;
105
+ bool seeded_;
106
+ uint32_t next_;
107
+ std::array<uint32_t, MERSENNE_STATE_N> state_;
108
+ };
109
+
110
+ class mt19937_engine {
111
+ public:
112
+
113
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
114
+ inline explicit mt19937_engine(uint64_t seed = 5489) {
115
+ init_with_uint32(seed);
116
+ }
117
+
118
+ inline mt19937_data_pod data() const {
119
+ return data_;
120
+ }
121
+
122
+ inline void set_data(const mt19937_data_pod& data) {
123
+ data_ = data;
124
+ }
125
+
126
+ inline uint64_t seed() const {
127
+ return data_.seed_;
128
+ }
129
+
130
+ inline bool is_valid() {
131
+ if ((data_.seeded_ == true)
132
+ && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N)
133
+ && (data_.next_ <= MERSENNE_STATE_N)) {
134
+ return true;
135
+ }
136
+ return false;
137
+ }
138
+
139
+ inline uint32_t operator()() {
140
+ if (--(data_.left_) == 0) {
141
+ next_state();
142
+ }
143
+ uint32_t y = *(data_.state_.data() + data_.next_++);
144
+ y ^= (y >> 11);
145
+ y ^= (y << 7) & 0x9d2c5680;
146
+ y ^= (y << 15) & 0xefc60000;
147
+ y ^= (y >> 18);
148
+
149
+ return y;
150
+ }
151
+
152
+ private:
153
+ mt19937_data_pod data_;
154
+
155
+ inline void init_with_uint32(uint64_t seed) {
156
+ data_.seed_ = seed;
157
+ data_.seeded_ = true;
158
+ data_.state_[0] = seed & 0xffffffff;
159
+ for (const auto j : c10::irange(1, MERSENNE_STATE_N)) {
160
+ data_.state_[j] = (1812433253 * (data_.state_[j-1] ^ (data_.state_[j-1] >> 30)) + j);
161
+ }
162
+ data_.left_ = 1;
163
+ data_.next_ = 0;
164
+ }
165
+
166
+ inline uint32_t mix_bits(uint32_t u, uint32_t v) {
167
+ return (u & UMASK) | (v & LMASK);
168
+ }
169
+
170
+ inline uint32_t twist(uint32_t u, uint32_t v) {
171
+ return (mix_bits(u,v) >> 1) ^ (v & 1 ? MATRIX_A : 0);
172
+ }
173
+
174
+ inline void next_state() {
175
+ uint32_t* p = data_.state_.data();
176
+ data_.left_ = MERSENNE_STATE_N;
177
+ data_.next_ = 0;
178
+
179
+ for(int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) {
180
+ *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]);
181
+ }
182
+
183
+ for(int j = MERSENNE_STATE_M; --j; p++) {
184
+ *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]);
185
+ }
186
+
187
+ *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]);
188
+ }
189
+
190
+ };
191
+
192
+ typedef mt19937_engine mt19937;
193
+
194
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/NamedTensor.h ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Dimname.h>
4
+ #include <c10/core/TensorImpl.h>
5
+
6
+ namespace at {
7
+
8
+ class TensorBase;
9
+
10
+ // XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen.
11
+ // Due to the c10/ATen library split, TensorImpl cannot depend on Dimname,
12
+ // so we have a couple of workarounds.
13
+ //
14
+ // In the long term, we'll move Dimname to c10 and everything in this file
15
+ // can be refactored out. The main blocker for that is that "c10::Symbol"
16
+ // actually exists outside of c10 and needs to be moved in.
17
+
18
+ // TensorImpl has a unique_ptr<NamedTensorMetaInterface> field.
19
+ // XXX: Ideally we would just put std::optional<vector<Dimname>> into TensorImpl.
20
+ //
21
+ // This class has an important invariant: there must be at least ONE
22
+ // non-wildcard
23
+ struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface {
24
+ // This enum is to remind people that the invariant on constructors is that
25
+ // the list of dimnames must have at least one non-wildcard
26
+ enum HAS_NON_WILDCARD {
27
+ HasNonWildcard
28
+ };
29
+
30
+ explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names)
31
+ : names_(names.vec()) {
32
+ check_invariants();
33
+ }
34
+ explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector<Dimname>&& names)
35
+ : names_(std::move(names)) {
36
+ check_invariants();
37
+ }
38
+
39
+ std::unique_ptr<c10::NamedTensorMetaInterface> clone() const override {
40
+ return std::make_unique<NamedTensorMeta>(HasNonWildcard, names_);
41
+ }
42
+
43
+ DimnameList names() const { return names_; }
44
+
45
+ // Used for an assertion in TensorImpl.h
46
+ int64_t slow_dim() const override {
47
+ return static_cast<int64_t>(names_.size());
48
+ }
49
+
50
+ void check_invariants() const {
51
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
52
+ std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); }));
53
+ }
54
+
55
+ void set_names(HAS_NON_WILDCARD, DimnameList new_names) {
56
+ TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
57
+ std::copy(new_names.begin(), new_names.end(), names_.begin());
58
+ check_invariants();
59
+ }
60
+
61
+ void set_names(HAS_NON_WILDCARD, std::vector<Dimname>&& new_names) {
62
+ TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
63
+ names_ = std::move(new_names);
64
+ check_invariants();
65
+ }
66
+
67
+ // INVARIANT: at least one Dimname is non-WILDCARD
68
+ std::vector<Dimname> names_;
69
+ };
70
+
71
+ // When NamesMode is disabled, then all operations ignore tensors' names fields.
72
+ // Concretely speaking, all tensors are treated as having nullopt names.
73
+ struct TORCH_API NamesMode {
74
+ static bool is_enabled();
75
+ static void set_enabled(bool enabled);
76
+ };
77
+
78
+
79
+ // A RAII, thread local (!) guard that enables or disables names upon
80
+ // construction, and sets it back to the original value upon destruction.
81
+ struct TORCH_API NoNamesGuard {
82
+ NoNamesGuard() : prev_mode(NamesMode::is_enabled()) {
83
+ NamesMode::set_enabled(false);
84
+ }
85
+ NoNamesGuard(const NoNamesGuard&) = delete;
86
+ NoNamesGuard(NoNamesGuard&&) = delete;
87
+ NoNamesGuard& operator=(const NoNamesGuard&) = delete;
88
+ NoNamesGuard& operator=(NoNamesGuard&&) = delete;
89
+ ~NoNamesGuard() {
90
+ if (initialized) {
91
+ reset();
92
+ }
93
+ }
94
+ void reset() {
95
+ TORCH_INTERNAL_ASSERT(initialized);
96
+ NamesMode::set_enabled(prev_mode);
97
+ }
98
+ private:
99
+ bool prev_mode;
100
+ bool initialized{true};
101
+ };
102
+
103
+ void check_names_valid_for(const TensorBase& tensor, DimnameList names);
104
+ void check_names_valid_for(size_t tensor_dim, DimnameList names);
105
+
106
+ // Sets the names of `tensor` to be `names`.
107
+ TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::optional<DimnameList> names);
108
+ TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector<Dimname>&& names, bool validate_names);
109
+
110
+ constexpr size_t kMaxNamedTensorDim = 64;
111
+
112
+ DimnameList default_names(size_t len);
113
+
114
+ namespace impl {
115
+
116
+ // Some helper functions on TensorImpl. Useful for working with names in TH.
117
+ // XXX: Ideally these would exist as methods on TensorImpl
118
+ TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::optional<DimnameList> names, bool validate_names);
119
+ TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names);
120
+
121
+ void check_names_valid_for(TensorImpl* impl, DimnameList names);
122
+
123
+ // Returns true if the tensor's names exist and are not all 'None'.
124
+ // Returns false if the tensor's names don't exist (were not allocated),
125
+ // or if all names are 'None'.
126
+ // We treat not-allocated-names the same as allocated names that are all 'None'.
127
+ TORCH_API bool has_names(const TensorImpl* impl);
128
+
129
+ // Returns the names of the tensor's dimensions.
130
+ // Unnamed tensors are treated as having 'None' in all dimension; this method
131
+ // would return a DimnameList of all 'None's for an unnamed tensor.
132
+ TORCH_API DimnameList get_names(const TensorImpl* impl);
133
+
134
+ // This is more of an implementation detail; one should use impl::get_names /
135
+ // Tensor::names() whenever possible because it provides a cleaner API.
136
+ // Returns the names of the tensor if they have been allocated; returns nullopt
137
+ // instead if they haven't been. The names of a tensor are not allocated if a
138
+ // tensor is constructed with names=None.
139
+ TORCH_API std::optional<DimnameList> get_opt_names(const TensorImpl* impl);
140
+
141
+ } // namespace impl
142
+
143
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/ConstantSymNodeImpl.h>
4
+ #include <c10/core/SymNodeImpl.h>
5
+ #include <c10/macros/Export.h>
6
+ #include <c10/util/Exception.h>
7
+ #include <c10/util/intrusive_ptr.h>
8
+ #include <cstdint>
9
+ #include <optional>
10
+ #include <string>
11
+
12
+ namespace c10 {
13
+
14
+ // The motivating usecase for this is to represent the ragged size structure
15
+ // of a jagged tensor [B, [s_0, s_1, s_2], D] as a single integer j0. This
16
+ // allows us to simply return [B, j0, D] if someone queries for the size of our
17
+ // tensor.
18
+ //
19
+ // Morally we define comparison between two nested ints to return true if
20
+ // that comparison holds for all corresponding elements of the arrays they
21
+ // represent. Comparison between a nested int and a plain int is defined
22
+ // similarly.
23
+ //
24
+ // To simulate this desired behavior but also avoid the O(N) cost of checking,
25
+ // we associate each raggedness pattern with an integer "id" that can be used as
26
+ // a proxy to evaluate equality. We also constrain the range of values for this
27
+ // as to enable inequality checks.
28
+ //
29
+ // We also support a positive integer scalar "coeff" that is used for computing
30
+ // strides. For example given, a [B, j0, D] tensor, it can be strided in two
31
+ // different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to
32
+ // differentiate the two cases.
33
+ //
34
+ // During tracing the strides of the outputs need to be a function of the size
35
+ // and strides of the inputs so it is important that NestedIntSymNode itself is
36
+ // able to express this.
37
+ class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
38
+ public:
39
+ // CAUTION: you should probably not be constructing these directly; please
40
+ // the higher-level API in python instead (TODO: actually introduce that).
41
+ explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff)
42
+ : val_(val), coeff_(coeff) {}
43
+
44
+ bool bool_() override {
45
+ return false;
46
+ }
47
+
48
+ bool is_int() override {
49
+ return true;
50
+ }
51
+
52
+ bool is_float() override {
53
+ return false;
54
+ }
55
+
56
+ bool is_bool() override {
57
+ return false;
58
+ }
59
+
60
+ bool is_nested_int() const override {
61
+ return true;
62
+ }
63
+
64
+ bool has_hint() override {
65
+ return true;
66
+ }
67
+
68
+ c10::SymNode wrap_int(int64_t num) override {
69
+ return SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(num));
70
+ }
71
+
72
+ int64_t guard_int(const char* file, int64_t line) override {
73
+ TORCH_CHECK(false);
74
+ }
75
+
76
+ double guard_float(const char* file, int64_t line) override {
77
+ TORCH_CHECK(false, "not a float");
78
+ }
79
+
80
+ bool guard_bool(const char* file, int64_t line) override {
81
+ TORCH_CHECK(false, "not a bool");
82
+ }
83
+
84
+ int64_t int_() override {
85
+ TORCH_CHECK(false);
86
+ }
87
+
88
+ std::string str() override {
89
+ if (coeff_ == 1) {
90
+ return "j" + std::to_string(val_);
91
+ }
92
+ return std::to_string(coeff_) + "*j" + std::to_string(val_);
93
+ }
94
+
95
+ // NOTE [ Inequalities with nested int ]
96
+ //
97
+ // The semantics of nested int when it comes to relations is that it is
98
+ // treated as integer known to be within a certain range,
99
+ //
100
+ // j0 \in [2, int64_t::max]
101
+ //
102
+ // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False).
103
+ // This is a useful default range for the raggedness pattern of a jagged
104
+ // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1
105
+ // specialization checks.
106
+ //
107
+ // [ Indeterminate inequalities error out ]
108
+ //
109
+ // Given the semantic defined above, certain relations like j0 < 3 are thus
110
+ // indeterminable. In our impl today, evaluating such relations error
111
+ //
112
+ // It may seem convenient to just define indeterminate relations to return
113
+ // False, but the implementation we maintain in parallel using sympy does not
114
+ // allow this.
115
+ //
116
+ // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are,
117
+ // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This
118
+ // would mean that means that if we define the indeterminate j0 >= 3 to be
119
+ // False, the also indeterminate j0 < 3 will be evaluated to be True!
120
+ //
121
+ // [ Coefficient are assumed positive ]
122
+ //
123
+ // For the purpose of computing inequalities, we consider the coefficient of
124
+ // the nested int to be a positive integer.
125
+ //
126
+ // Thus, no modifications are needed to the logic since
127
+ // j0 >= k implies coeff * j0 >= k
128
+ //
129
+ c10::SymNode eq(const c10::SymNode& other) override;
130
+ c10::SymNode ne(const c10::SymNode& other) override;
131
+ c10::SymNode ge(const c10::SymNode& other) override;
132
+ c10::SymNode gt(const c10::SymNode& other) override;
133
+ c10::SymNode lt(const c10::SymNode& other) override;
134
+ c10::SymNode le(const c10::SymNode& other) override;
135
+ c10::SymNode mul(const c10::SymNode& other) override;
136
+
137
+ std::optional<int64_t> nested_int() override {
138
+ return val_;
139
+ }
140
+
141
+ std::optional<int64_t> nested_int_coeff() override {
142
+ return coeff_;
143
+ }
144
+
145
+ bool is_symbolic() override {
146
+ return false;
147
+ }
148
+
149
+ c10::SymNode clone() override;
150
+
151
+ #define DEFINE_BINARY_NOT_SUPPORTED(name) \
152
+ c10::SymNode name(const c10::SymNode& other) override { \
153
+ TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \
154
+ }
155
+
156
+ DEFINE_BINARY_NOT_SUPPORTED(add)
157
+ DEFINE_BINARY_NOT_SUPPORTED(sub)
158
+ DEFINE_BINARY_NOT_SUPPORTED(truediv)
159
+ DEFINE_BINARY_NOT_SUPPORTED(pow)
160
+ DEFINE_BINARY_NOT_SUPPORTED(floordiv)
161
+ DEFINE_BINARY_NOT_SUPPORTED(mod)
162
+ DEFINE_BINARY_NOT_SUPPORTED(sym_min)
163
+ DEFINE_BINARY_NOT_SUPPORTED(sym_max)
164
+ DEFINE_BINARY_NOT_SUPPORTED(sym_and)
165
+ DEFINE_BINARY_NOT_SUPPORTED(sym_or)
166
+
167
+ #undef DEFINE_BINARY_NOT_SUPPORTED
168
+
169
+ #define DEFINE_NOT_SUPPORTED(name) \
170
+ c10::SymNode name() override { \
171
+ TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \
172
+ }
173
+
174
+ DEFINE_NOT_SUPPORTED(sym_not)
175
+ DEFINE_NOT_SUPPORTED(ceil)
176
+ DEFINE_NOT_SUPPORTED(floor)
177
+ DEFINE_NOT_SUPPORTED(neg)
178
+ DEFINE_NOT_SUPPORTED(sym_float)
179
+
180
+ #undef DEFINE_NOT_SUPPORTED
181
+
182
+ private:
183
+ int64_t val_;
184
+ int64_t coeff_;
185
+ };
186
+
187
+ } // namespace c10
.venv/lib/python3.12/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
// define constants like M_PI and C keywords for MSVC
// The #ifndef guard avoids a macro-redefinition diagnostic when
// _USE_MATH_DEFINES has already been defined (e.g. on the compiler command
// line or by another header included earlier).
#ifdef _MSC_VER
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <math.h>
#endif
8
+
9
+
10
+ #ifdef __CUDACC__
11
+ #include <cuda.h>
12
+ #endif
13
+
14
+ #include <array>
15
+ #include <c10/macros/Macros.h>
16
+ #include <cmath>
17
+ #include <cstdint>
18
+
19
+ namespace at {
20
+
21
+ // typedefs for holding vector data
22
namespace detail {

// Fixed-size vector aliases used by the Philox engine.
using UINT4 = std::array<uint32_t, 4>;
using UINT2 = std::array<uint32_t, 2>;
using DOUBLE2 = std::array<double, 2>;
using FLOAT2 = std::array<float, 2>;

} // namespace detail
30
+
31
+ /**
32
+ * Note [Philox Engine implementation]
33
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
34
+ * Originally implemented in PyTorch's fusion compiler
35
+ * Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
36
+ * for details regarding the engine.
37
+ *
38
+ * Note that currently this implementation of the philox engine is not used
39
+ * anywhere except for tests in cpu_generator_test.cpp. However, this engine
40
+ * will replace curandStatePhilox4_32_10_t in the future.
41
+ *
42
+ * The philox engine takes a seed value, a subsequence
43
+ * for starting the generation and an offset for the subsequence.
44
+ * Think of this engine as an algorithm producing a huge array. We are
45
+ * parallelizing this array by partitioning the huge array and assigning
46
+ * a thread index to each partition. In other words, each seed value
47
+ * (there are 2^64 possible seed values) gives a sub array of size
48
+ * 2^128 (each element in that array is a 128 bit number). Reasoning
49
+ * behind the array being of size 2^128 is, there are 2^64 possible
50
+ * thread index value and there is an array of size 2^64 for each of
51
+ * those thread index. Hence 2^64 * 2^64 = 2^128 for each seed value.
52
+ *
53
+ * In short, this generator can produce 2^64 (seed values) * 2^128 (number
54
+ * of elements in an array given by a seed value) = 2^192 values.
55
+ *
56
+ * Arguments:
57
+ * seed: Seed values could be any number from 0 to 2^64-1.
58
+ * subsequence: Subsequence is just the cuda thread indexing with:
59
+ * - blockIdx.x * blockDim.x + threadIdx.x
60
+ * offset: The offset variable in PhiloxEngine decides how many 128-bit
61
+ * random numbers to skip (i.e. how many groups of 4, 32-bit numbers to skip)
62
+ * and hence really decides the total number of randoms that can be achieved
63
+ * for the given subsequence.
64
+ */
65
+
66
+ class philox_engine {
67
+ public:
68
+
69
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
70
+ C10_HOST_DEVICE inline explicit philox_engine(uint64_t seed = 67280421310721,
71
+ uint64_t subsequence = 0,
72
+ uint64_t offset = 0) {
73
+
74
+ reset_state(seed, subsequence);
75
+ incr_n(offset);
76
+ }
77
+
78
+ C10_HOST_DEVICE inline void reset_state(uint64_t seed = 67280421310721,
79
+ uint64_t subsequence = 0) {
80
+ key_[0] = static_cast<uint32_t>(seed);
81
+ key_[1] = static_cast<uint32_t>(seed >> 32);
82
+ counter_ = detail::UINT4{};
83
+ counter_[2] = static_cast<uint32_t>(subsequence);
84
+ counter_[3] = static_cast<uint32_t>(subsequence >> 32);
85
+ STATE = 0;
86
+ }
87
+
88
+ /**
89
+ * Set the offset field of Philox Generator to the desired offset.
90
+ */
91
+ C10_HOST_DEVICE inline void set_offset(uint64_t offset) {
92
+ counter_[0] = static_cast<uint32_t>(offset);
93
+ counter_[1] = static_cast<uint32_t>(offset >> 32);
94
+ }
95
+
96
+ /**
97
+ * Gets the current offset of the Philox Generator.
98
+ */
99
+ C10_HOST_DEVICE uint64_t get_offset() const {
100
+ uint64_t lo = static_cast<uint64_t>(counter_[0]);
101
+ uint64_t hi = static_cast<uint64_t>(counter_[1]) << 32;
102
+ return lo | hi;
103
+ }
104
+
105
+ /**
106
+ * Produces a unique 32-bit pseudo random number on every invocation. Bookeeps state to avoid waste.
107
+ */
108
+ C10_HOST_DEVICE inline uint32_t operator()(int32_t n_rounds = 10) { // 10 here to preserve back-compat behavior
109
+ if(STATE == 0) {
110
+ detail::UINT4 counter = counter_;
111
+ detail::UINT2 key = key_;
112
+ output_ = rand(counter, key, n_rounds);
113
+ incr();
114
+ }
115
+ uint32_t ret = output_[static_cast<int>(STATE)];
116
+ STATE = (STATE + 1) & 3;
117
+ return ret;
118
+ }
119
+
120
+ inline float randn(uint32_t n_rounds) {
121
+ #ifdef __CUDA_ARCH__
122
+ AT_ASSERT(false, "Unsupported invocation of randn on CUDA");
123
+ #endif
124
+ if(STATE == 0) {
125
+ detail::UINT4 counter = counter_;
126
+ detail::UINT2 key = key_;
127
+ output_ = rand(counter, key, n_rounds);
128
+ incr();
129
+ }
130
+ // TODO(min-jean-cho) change to Polar method, a more efficient version of Box-Muller method
131
+ // TODO(voz) We use std:: below, and thus need a separate impl for CUDA.
132
+ float u1 = 1 - uint32_to_uniform_float(output_[0]); // uint32_to_uniform_float returns [0,1), we need (0,1] to avoid passing 0 to log.
133
+ float u2 = 1 - uint32_to_uniform_float(output_[1]);
134
+ return static_cast<float>(std::sqrt(-2.0 * std::log(u1)) * std::cos(2.0 * M_PI * u2));
135
+ }
136
+
137
+ /**
138
+ * Function that Skips N 128 bit numbers in a subsequence
139
+ */
140
+ C10_HOST_DEVICE inline void incr_n(uint64_t n) {
141
+ uint32_t nlo = static_cast<uint32_t>(n);
142
+ uint32_t nhi = static_cast<uint32_t>(n >> 32);
143
+ counter_[0] += nlo;
144
+ // if overflow in x has occurred, carry over to nhi
145
+ if (counter_[0] < nlo) {
146
+ nhi++;
147
+ // if overflow in nhi has occurred during carry over,
148
+ // propagate that overflow to y and exit to increment z
149
+ // otherwise return
150
+ counter_[1] += nhi;
151
+ if(nhi != 0) {
152
+ if (nhi <= counter_[1]) {
153
+ return;
154
+ }
155
+ }
156
+ } else {
157
+ // if overflow in y has occurred during addition,
158
+ // exit to increment z
159
+ // otherwise return
160
+ counter_[1] += nhi;
161
+ if (nhi <= counter_[1]) {
162
+ return;
163
+ }
164
+ }
165
+ if (++counter_[2])
166
+ return;
167
+ ++counter_[3];
168
+ }
169
+
170
+ /**
171
+ * Function that Skips one 128 bit number in a subsequence
172
+ */
173
+ C10_HOST_DEVICE inline void incr() {
174
+ if (++counter_[0])
175
+ return;
176
+ if (++counter_[1])
177
+ return;
178
+ if (++counter_[2]) {
179
+ return;
180
+ }
181
+ ++counter_[3];
182
+ }
183
+
184
+ private:
185
+ detail::UINT4 counter_;
186
+ detail::UINT4 output_;
187
+ detail::UINT2 key_;
188
+ uint32_t STATE;
189
+
190
+ C10_HOST_DEVICE inline uint32_t mulhilo32(uint32_t a, uint32_t b,
191
+ uint32_t *result_high) {
192
+ #ifdef __CUDA_ARCH__
193
+ *result_high = __umulhi(a, b);
194
+ return a*b;
195
+ #else
196
+ const uint64_t product = static_cast<uint64_t>(a) * b;
197
+ *result_high = static_cast<uint32_t>(product >> 32);
198
+ return static_cast<uint32_t>(product);
199
+ #endif
200
+ }
201
+
202
+ C10_HOST_DEVICE inline detail::UINT4 single_round(detail::UINT4 ctr, detail::UINT2 in_key) {
203
+ uint32_t hi0 = 0;
204
+ uint32_t hi1 = 0;
205
+ uint32_t lo0 = mulhilo32(kPhiloxSA, ctr[0], &hi0);
206
+ uint32_t lo1 = mulhilo32(kPhiloxSB, ctr[2], &hi1);
207
+ detail::UINT4 ret;
208
+ ret[0] = hi1 ^ ctr[1] ^ in_key[0];
209
+ ret[1] = lo1;
210
+ ret[2] = hi0 ^ ctr[3] ^ in_key[1];
211
+ ret[3] = lo0;
212
+ return ret;
213
+ }
214
+
215
+ C10_HOST_DEVICE constexpr float uint32_to_uniform_float(uint32_t value) {
216
+ // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
217
+ constexpr float scale = 4.6566127342e-10;
218
+ return static_cast<float>(value & 0x7FFFFFFF) * scale;
219
+ }
220
+
221
+
222
+
223
+ C10_HOST_DEVICE inline detail::UINT4 rand(detail::UINT4& counter, detail::UINT2& key, uint32_t n_rounds) {
224
+ for (uint32_t round = 0; round < (n_rounds - 1); round++) {
225
+ counter = single_round(counter, key);
226
+ key[0] += (kPhilox10A); key[1] += (kPhilox10B);
227
+ }
228
+ return single_round(counter, key);
229
+ }
230
+
231
+
232
+ static const uint32_t kPhilox10A = 0x9E3779B9;
233
+ static const uint32_t kPhilox10B = 0xBB67AE85;
234
+ static const uint32_t kPhiloxSA = 0xD2511F53;
235
+ static const uint32_t kPhiloxSB = 0xCD9E8D57;
236
+ };
237
+
238
+ typedef philox_engine Philox4_32;
239
+
240
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonFallbackKernel.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/TorchDispatchUtils.h>
3
+
4
+
5
+ namespace at::impl {
6
+
7
+ struct TORCH_API RestorePythonTLSSnapshot {
8
+ RestorePythonTLSSnapshot();
9
+ RestorePythonTLSSnapshot(RestorePythonTLSSnapshot&& other) = delete;
10
+ RestorePythonTLSSnapshot(const RestorePythonTLSSnapshot&) = delete;
11
+ RestorePythonTLSSnapshot& operator=(const RestorePythonTLSSnapshot&) = delete;
12
+ RestorePythonTLSSnapshot& operator=(RestorePythonTLSSnapshot&&) = delete;
13
+ ~RestorePythonTLSSnapshot();
14
+
15
+ private:
16
+ c10::impl::LocalDispatchKeySet saved_;
17
+ c10::impl::ForceDispatchKeyGuard guard_;
18
+ };
19
+
20
+
21
+ // RAII guard to make working with the above TLS safer.
22
+ struct TORCH_API MaybeSetTLSOnEntryGuard {
23
+ public:
24
+ MaybeSetTLSOnEntryGuard();
25
+ MaybeSetTLSOnEntryGuard(MaybeSetTLSOnEntryGuard&& other) = delete;
26
+ MaybeSetTLSOnEntryGuard(const MaybeSetTLSOnEntryGuard&) = delete;
27
+ MaybeSetTLSOnEntryGuard& operator=(const MaybeSetTLSOnEntryGuard&) = delete;
28
+ MaybeSetTLSOnEntryGuard& operator=(MaybeSetTLSOnEntryGuard&&) = delete;
29
+ ~MaybeSetTLSOnEntryGuard();
30
+
31
+ private:
32
+ bool value_set_;
33
+ };
34
+
35
+ } // namespace at::impl
.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/dispatch/Dispatcher.h>
4
+
5
+ // TODO: this can probably live in c10
6
+
7
+
8
+ namespace at::impl {
9
+
10
+ class TORCH_API PythonOpRegistrationTrampoline final {
11
+ static std::atomic<c10::impl::PyInterpreter*> interpreter_;
12
+
13
+ public:
14
+ // Returns true if you successfully registered yourself (that means
15
+ // you are in the hot seat for doing the operator registrations!)
16
+ static bool registerInterpreter(c10::impl::PyInterpreter*);
17
+
18
+ // Returns nullptr if no interpreter has been registered yet.
19
+ static c10::impl::PyInterpreter* getInterpreter();
20
+ };
21
+
22
+ } // namespace at::impl
.venv/lib/python3.12/site-packages/torch/include/ATen/core/QuantizerBase.h ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/ScalarType.h>
4
+ #include <c10/core/QScheme.h>
5
+ #include <c10/util/intrusive_ptr.h>
6
+
7
+ namespace at {
8
+
9
+ class Tensor;
10
+ struct QTensorImpl;
11
+ struct Quantizer;
12
+ using ConstQuantizerPtr = const c10::intrusive_ptr<Quantizer>&;
13
+ using QuantizerPtr = c10::intrusive_ptr<Quantizer>;
14
+
15
+ /**
16
+ * Quantizer is the class for storing all the information
17
+ * that's necessary to perform quantize and dequantize
18
+ * operation.
19
+ *
20
+ * We might have different types of quantization schemes and this is
21
+ * the base class for all quantizers.
22
+ *
23
+ * QTensorImpl will hold a pointer to Quantizer so that we can support
24
+ * different quantization schemes on Tensor.
25
+ *
26
+ * For example, the most common quantization scheme, Affine Quantization,
27
+ * requires scale and zero_point as parameters, we'll store scale and zero_point
28
+ * inside the instance and we can use it to quantize a float Tensor or
29
+ * dequantize a quantized Tensor.
30
+ *
31
+ * When you add new types of leaf Quantizer class, please also
32
+ * make sure to add a corresponding QScheme enum since
33
+ * they should have one to one mapping.
34
+ *
35
+ * Note about intrusive_ptr:
36
+ * Quantized Tensor holds an intrusive_ptr to Quantizer, and multiple Tensor can
37
+ * share the same Quantizer. Quantizer should be immutable.
38
+ */
39
+ struct TORCH_API Quantizer : public c10::intrusive_ptr_target {
40
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
41
+ const ScalarType scalar_type_;
42
+ explicit Quantizer(ScalarType scalar_type) : scalar_type_(scalar_type) {}
43
+ ~Quantizer() override = default;
44
+
45
+ // Copied from torch/csrc/jit/ir/scope.h
46
+ QuantizerPtr intrusive_from_this() {
47
+ c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
48
+ // from a raw `this` pointer
49
+ // so we need to bump the refcount
50
+ // to account for this ownership
51
+ return c10::intrusive_ptr<Quantizer>::reclaim(this);
52
+ }
53
+
54
+ /**
55
+ * Each concrete Quantizer type should have a unique QScheme type.
56
+ */
57
+ virtual QScheme qscheme() const = 0;
58
+
59
+ ScalarType scalar_type() const {
60
+ return scalar_type_;
61
+ }
62
+
63
+ /**
64
+ * quantize a float Tensor into a quantized Tensor.
65
+ */
66
+ virtual Tensor quantize(const Tensor& t) = 0;
67
+
68
+ /**
69
+ * dequantize a quantized Tensor into a float Tensor.
70
+ */
71
+ virtual Tensor dequantize(const Tensor& t) = 0;
72
+
73
+ /**
74
+ * dequantize a quantized Tensor into a float Tensor, out= variant
75
+ */
76
+ virtual Tensor& dequantize_out(Tensor& out, const Tensor& t) = 0;
77
+
78
+ /**
79
+ * Compare against `other` for equality.
80
+ */
81
+ virtual bool equalTo(QuantizerPtr other) const = 0;
82
+ };
83
+
84
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Range.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cstdint>
4
+ #include <iosfwd>
5
+
6
+ namespace at {
7
+
8
// Pair of int64_t bounds, `begin` and `end`.
struct Range {
  Range(int64_t begin, int64_t end)
    : begin(begin)
    , end(end) {}

  // Distance from `begin` to `end` (negative when end < begin).
  int64_t size() const { return end - begin; }

  // Returns a new Range with both bounds divided by `divisor` using integer
  // division. Does not modify this Range, so it is const-qualified and can be
  // applied to const objects and temporaries alike. `divisor` must be
  // non-zero.
  Range operator/(int64_t divisor) const {
    return Range(begin / divisor, end / divisor);
  }

  int64_t begin;
  int64_t end;
};
22
+
23
+ std::ostream& operator<<(std::ostream& out, const Range& range);
24
+
25
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Reduction.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace at::Reduction {
4
+
5
// NB: Keep this in sync with Reduction class in torch/nn/_reduction.py
// These constants control the reduction behavior of loss functions.
// Ideally, this would be a scoped enum, but jit doesn't support that
enum Reduction {
  None = 0, // Do not reduce
  Mean = 1, // (Possibly weighted) mean of losses
  Sum = 2,  // Sum losses
  END = 3   // Follows the last reduction mode
};
14
+ } // namespace at::Reduction
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Scalar.h ADDED
@@ -0,0 +1 @@
 
 
1
+ #include <c10/core/Scalar.h>
.venv/lib/python3.12/site-packages/torch/include/ATen/core/ScalarType.h ADDED
@@ -0,0 +1 @@
 
 
1
+ #include <c10/core/ScalarType.h>
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Tensor.h ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/TensorBody.h>
4
+ #include <c10/util/Exception.h>
5
+
6
+ namespace at {
7
+ // NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
8
+ class TORCH_API OptionalTensorRef {
9
+ public:
10
+ OptionalTensorRef() = default;
11
+
12
+ ~OptionalTensorRef() {
13
+ ref_.unsafeReleaseTensorImpl();
14
+ }
15
+
16
+ OptionalTensorRef(const TensorBase& src)
17
+ : ref_(Tensor::unsafe_borrow_t{}, src) {
18
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
19
+ }
20
+
21
+ OptionalTensorRef(const OptionalTensorRef& rhs)
22
+ : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {}
23
+
24
+ OptionalTensorRef(OptionalTensorRef&& rhs) = default;
25
+ OptionalTensorRef& operator=(OptionalTensorRef rhs) {
26
+ std::swap(ref_, rhs.ref_);
27
+ return *this;
28
+ }
29
+
30
+ bool has_value() const {
31
+ return ref_.defined();
32
+ }
33
+
34
+ const Tensor& getTensorRef() const & {
35
+ return ref_;
36
+ }
37
+
38
+ const Tensor& operator*() const & {
39
+ return ref_;
40
+ }
41
+
42
+ const Tensor* operator->() const & {
43
+ return &ref_;
44
+ }
45
+
46
+ operator bool() const {
47
+ return ref_.defined();
48
+ }
49
+
50
+ private:
51
+ Tensor ref_;
52
+ };
53
+
54
+ // Use to convert a TensorBase (that may be undefined) to an at::Tensor
55
+ // without bumping refcount.
56
+ class TORCH_API TensorRef {
57
+ public:
58
+ ~TensorRef() {
59
+ ref_.unsafeReleaseTensorImpl();
60
+ }
61
+
62
+ TensorRef(const TensorBase& src)
63
+ : ref_(Tensor::unsafe_borrow_t{}, src) {}
64
+ TensorRef(TensorRef&& other) = default;
65
+ TensorRef(const TensorRef&) = default;
66
+ TensorRef& operator=(const TensorRef&) = default;
67
+ TensorRef& operator=(TensorRef&&) = default;
68
+
69
+ const Tensor& operator*() const & {
70
+ return ref_;
71
+ }
72
+ private:
73
+ Tensor ref_;
74
+ };
75
+
76
+ template <typename T>
77
+ auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> {
78
+ // Return the grad argument in case of a hook with void return type to have an
79
+ // std::function with Tensor return type
80
+ static_assert(std::is_same_v<decltype(hook(Tensor())), void>,
81
+ "Expected hook to return void");
82
+ return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
83
+ TensorRef grad(grad_base);
84
+ fn(*grad);
85
+ return Tensor();
86
+ });
87
+ }
88
+
89
+ template <typename T>
90
+ auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> {
91
+ return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
92
+ TensorRef grad(grad_base);
93
+ Tensor ret = fn(*grad);
94
+ return TensorBase(std::move(ret));
95
+ });
96
+ }
97
+
98
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorAccessor.h ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/macros/Macros.h>
4
+ #include <c10/util/ArrayRef.h>
5
+ #include <c10/util/Deprecated.h>
6
+ #include <c10/util/Exception.h>
7
+ #include <c10/util/irange.h>
8
+ #include <cstddef>
9
+ #include <cstdint>
10
+ #include <type_traits>
11
+
12
+ namespace at {
13
+
14
+ // The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
15
+ // is used to enable the __restrict__ keyword/modifier for the data
16
+ // passed to cuda.
17
// Default pointer traits: PtrType is a plain `T*`.
template <typename T>
struct DefaultPtrTraits {
  using PtrType = T*;
};
21
+
22
#if defined(__CUDACC__) || defined(__HIPCC__)
// Pointer traits whose PtrType is a `__restrict__`-qualified `T*`. Only
// declared when compiling with a CUDA or HIP compiler.
template <typename T>
struct RestrictPtrTraits {
  using PtrType = T* __restrict__;
};
#endif
28
+
29
+ // TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.
30
+ // For CUDA tensors it is used in device code (only). This means that we restrict ourselves
31
+ // to functions and types available there (e.g. IntArrayRef isn't).
32
+
33
+ // The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.
34
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
35
+ class TensorAccessorBase {
36
+ public:
37
+ typedef typename PtrTraits<T>::PtrType PtrType;
38
+
39
+ C10_HOST_DEVICE TensorAccessorBase(
40
+ PtrType data_,
41
+ const index_t* sizes_,
42
+ const index_t* strides_)
43
+ : data_(data_), sizes_(sizes_), strides_(strides_) {}
44
+ C10_HOST IntArrayRef sizes() const {
45
+ return IntArrayRef(sizes_,N);
46
+ }
47
+ C10_HOST IntArrayRef strides() const {
48
+ return IntArrayRef(strides_,N);
49
+ }
50
+ C10_HOST_DEVICE index_t stride(index_t i) const {
51
+ return strides_[i];
52
+ }
53
+ C10_HOST_DEVICE index_t size(index_t i) const {
54
+ return sizes_[i];
55
+ }
56
+ C10_HOST_DEVICE PtrType data() {
57
+ return data_;
58
+ }
59
+ C10_HOST_DEVICE const PtrType data() const {
60
+ return data_;
61
+ }
62
+ protected:
63
+ PtrType data_;
64
+ const index_t* sizes_;
65
+ const index_t* strides_;
66
+ };
67
+
68
+ // The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
69
+ // `Tensor.accessor<T, N>()`.
70
+ // For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only
71
+ // indexing on the device uses `TensorAccessor`s.
72
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
73
+ class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {
74
+ public:
75
+ typedef typename PtrTraits<T>::PtrType PtrType;
76
+
77
+ C10_HOST_DEVICE TensorAccessor(
78
+ PtrType data_,
79
+ const index_t* sizes_,
80
+ const index_t* strides_)
81
+ : TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}
82
+
83
+ C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
84
+ return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
85
+ }
86
+
87
+ C10_HOST_DEVICE const TensorAccessor<T, N-1, PtrTraits, index_t> operator[](index_t i) const {
88
+ return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
89
+ }
90
+ };
91
+
92
+ template<typename T, template <typename U> class PtrTraits, typename index_t>
93
+ class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrTraits,index_t> {
94
+ public:
95
+ typedef typename PtrTraits<T>::PtrType PtrType;
96
+
97
+ C10_HOST_DEVICE TensorAccessor(
98
+ PtrType data_,
99
+ const index_t* sizes_,
100
+ const index_t* strides_)
101
+ : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
102
+ C10_HOST_DEVICE T & operator[](index_t i) {
103
+ // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
104
+ return this->data_[this->strides_[0]*i];
105
+ }
106
+ C10_HOST_DEVICE const T & operator[](index_t i) const {
107
+ return this->data_[this->strides_[0]*i];
108
+ }
109
+ };
110
+
111
+
112
+ // GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used on for CUDA `Tensor`s on the host
113
+ // and as
114
+ // In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)
115
+ // in order to transfer them on the device when calling kernels.
116
+ // On the device, indexing of multidimensional tensors gives to `TensorAccessor`s.
117
+ // Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.
118
+ // Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
119
+ // on the device, so those functions are host only.
120
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
121
+ class GenericPackedTensorAccessorBase {
122
+ public:
123
+ typedef typename PtrTraits<T>::PtrType PtrType;
124
+ C10_HOST GenericPackedTensorAccessorBase(
125
+ PtrType data_,
126
+ const index_t* sizes_,
127
+ const index_t* strides_)
128
+ : data_(data_) {
129
+ std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
130
+ std::copy(strides_, strides_ + N, std::begin(this->strides_));
131
+ }
132
+
133
+ // if index_t is not int64_t, we want to have an int64_t constructor
134
+ template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
135
+ C10_HOST GenericPackedTensorAccessorBase(
136
+ PtrType data_,
137
+ const source_index_t* sizes_,
138
+ const source_index_t* strides_)
139
+ : data_(data_) {
140
+ for (const auto i : c10::irange(N)) {
141
+ this->sizes_[i] = sizes_[i];
142
+ this->strides_[i] = strides_[i];
143
+ }
144
+ }
145
+
146
+ C10_HOST_DEVICE index_t stride(index_t i) const {
147
+ return strides_[i];
148
+ }
149
+ C10_HOST_DEVICE index_t size(index_t i) const {
150
+ return sizes_[i];
151
+ }
152
+ C10_HOST_DEVICE PtrType data() {
153
+ return data_;
154
+ }
155
+ C10_HOST_DEVICE const PtrType data() const {
156
+ return data_;
157
+ }
158
+ protected:
159
+ PtrType data_;
160
+ // NOLINTNEXTLINE(*c-arrays*)
161
+ index_t sizes_[N];
162
+ // NOLINTNEXTLINE(*c-arrays*)
163
+ index_t strides_[N];
164
+ C10_HOST void bounds_check_(index_t i) const {
165
+ TORCH_CHECK_INDEX(
166
+ 0 <= i && i < index_t{N},
167
+ "Index ",
168
+ i,
169
+ " is not within bounds of a tensor of dimension ",
170
+ N);
171
+ }
172
+ };
173
+
174
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
175
+ class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T,N,PtrTraits,index_t> {
176
+ public:
177
+ typedef typename PtrTraits<T>::PtrType PtrType;
178
+
179
+ C10_HOST GenericPackedTensorAccessor(
180
+ PtrType data_,
181
+ const index_t* sizes_,
182
+ const index_t* strides_)
183
+ : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
184
+
185
+ // if index_t is not int64_t, we want to have an int64_t constructor
186
+ template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
187
+ C10_HOST GenericPackedTensorAccessor(
188
+ PtrType data_,
189
+ const source_index_t* sizes_,
190
+ const source_index_t* strides_)
191
+ : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
192
+
193
+ C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
194
+ index_t* new_sizes = this->sizes_ + 1;
195
+ index_t* new_strides = this->strides_ + 1;
196
+ return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
197
+ }
198
+
199
+ C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
200
+ const index_t* new_sizes = this->sizes_ + 1;
201
+ const index_t* new_strides = this->strides_ + 1;
202
+ return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
203
+ }
204
+
205
+ /// Returns a PackedTensorAccessor of the same dimension after transposing the
206
+ /// two dimensions given. Does not actually move elements; transposition is
207
+ /// made by permuting the size/stride arrays. If the dimensions are not valid,
208
+ /// asserts.
209
+ C10_HOST GenericPackedTensorAccessor<T, N, PtrTraits, index_t> transpose(
210
+ index_t dim1,
211
+ index_t dim2) const {
212
+ this->bounds_check_(dim1);
213
+ this->bounds_check_(dim2);
214
+ GenericPackedTensorAccessor<T, N, PtrTraits, index_t> result(
215
+ this->data_, this->sizes_, this->strides_);
216
+ std::swap(result.strides_[dim1], result.strides_[dim2]);
217
+ std::swap(result.sizes_[dim1], result.sizes_[dim2]);
218
+ return result;
219
+ }
220
+ };
221
+
222
+ template<typename T, template <typename U> class PtrTraits, typename index_t>
223
+ class GenericPackedTensorAccessor<T,1,PtrTraits,index_t> : public GenericPackedTensorAccessorBase<T,1,PtrTraits,index_t> {
224
+ public:
225
+ typedef typename PtrTraits<T>::PtrType PtrType;
226
+ C10_HOST GenericPackedTensorAccessor(
227
+ PtrType data_,
228
+ const index_t* sizes_,
229
+ const index_t* strides_)
230
+ : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
231
+
232
+ // if index_t is not int64_t, we want to have an int64_t constructor
233
+ template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
234
+ C10_HOST GenericPackedTensorAccessor(
235
+ PtrType data_,
236
+ const source_index_t* sizes_,
237
+ const source_index_t* strides_)
238
+ : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
239
+
240
+ C10_DEVICE T & operator[](index_t i) {
241
+ return this->data_[this->strides_[0] * i];
242
+ }
243
+ C10_DEVICE const T& operator[](index_t i) const {
244
+ return this->data_[this->strides_[0]*i];
245
+ }
246
+
247
+ // Same as in the general N-dimensional case, but note that in the
248
+ // 1-dimensional case the returned PackedTensorAccessor will always be an
249
+ // identical copy of the original
250
+ C10_HOST GenericPackedTensorAccessor<T, 1, PtrTraits, index_t> transpose(
251
+ index_t dim1,
252
+ index_t dim2) const {
253
+ this->bounds_check_(dim1);
254
+ this->bounds_check_(dim2);
255
+ return GenericPackedTensorAccessor<T, 1, PtrTraits, index_t>(
256
+ this->data_, this->sizes_, this->strides_);
257
+ }
258
+ };
259
+
260
+
261
+ // Can't put this directly into the macro function args because of commas
262
+ #define AT_X GenericPackedTensorAccessor<T, N, PtrTraits, index_t>
263
+
264
+ // Old name for `GenericPackedTensorAccessor`
265
+ template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
266
+ C10_DEFINE_DEPRECATED_USING(PackedTensorAccessor, AT_X)
267
+
268
+ #undef AT_X
269
+
270
+ template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
271
+ using PackedTensorAccessor32 = GenericPackedTensorAccessor<T, N, PtrTraits, int32_t>;
272
+
273
+ template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
274
+ using PackedTensorAccessor64 = GenericPackedTensorAccessor<T, N, PtrTraits, int64_t>;
275
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBase.h ADDED
@@ -0,0 +1,1056 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/Device.h>
4
+ #include <c10/core/Layout.h>
5
+ #include <c10/core/MemoryFormat.h>
6
+ #include <c10/core/ScalarType.h>
7
+ #include <c10/core/ScalarTypeToTypeMeta.h>
8
+ #include <c10/core/Storage.h>
9
+ #include <c10/core/SymIntArrayRef.h>
10
+ #include <c10/core/TensorImpl.h>
11
+ #include <c10/core/TensorOptions.h>
12
+ #include <c10/core/UndefinedTensorImpl.h>
13
+ #include <c10/core/WrapDimMinimal.h>
14
+ #include <c10/util/C++17.h>
15
+ #include <c10/util/Exception.h>
16
+ #include <c10/util/ExclusivelyOwned.h>
17
+ #include <c10/util/ExclusivelyOwnedTensorTraits.h>
18
+ #include <c10/util/MaybeOwned.h>
19
+ #include <optional>
20
+ #include <c10/util/intrusive_ptr.h>
21
+
22
+ #include <ATen/core/NamedTensor.h>
23
+ #include <ATen/core/QuantizerBase.h>
24
+ #include <ATen/core/TensorAccessor.h>
25
+ #include <ATen/StorageUtils.h>
26
+
27
+ namespace c10 {
28
+ class Scalar;
29
+ }
30
+
31
+ namespace torch::autograd {
32
+
33
+ struct Node;
34
+
35
+ } // namespace torch::autograd
36
+
37
+ namespace at {
38
+
39
+ class Tensor;
40
+ class TensorBase;
41
+
42
+ // Convert Tensor to TensorBase without any need to include Tensor.h
43
+ TORCH_API const TensorBase& get_tensor_base(const Tensor& t);
44
+
45
+ namespace impl {
46
+ inline bool variable_excluded_from_dispatch() {
47
+ #ifdef C10_MOBILE
48
+ // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
49
+ return true;
50
+ #else
51
+ return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
52
+ #endif
53
+ }
54
+
55
+ }
56
+
57
+ // NOTE: [Tensor vs. TensorBase]
58
+ //
59
+ // Tensor, being the central data structure in PyTorch, gets used and
60
+ // its header included almost everywhere. Unfortunately this means
61
+ // every time an operator signature is updated or changed in
62
+ // native_functions.yaml, you (and every other PyTorch developer) need
63
+ // to recompile all of ATen and its dependencies.
64
+ //
65
+ // TensorBase aims to break up these header dependencies, and improve
66
+ // incremental build times for all PyTorch developers. TensorBase
67
+ // represents a reference counted handle to TensorImpl, exactly the
68
+ // same as Tensor. However, TensorBase doesn't have code generated
69
+ // methods in its API and thus no dependence on native_functions.yaml.
70
+ //
71
+ // Usage tips
72
+ // ----------
73
+ // - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp
74
+ // or .cu file to ensure it has no header dependencies on
75
+ // native_functions.yaml (direct or indirect).
76
+ // - Tensor inherits from TensorBase, so functions taking
77
+ // `const TensorBase &` are callable with Tensor as well.
78
+ // - TensorBase can be converted to Tensor with `Tensor(tensor_base)`,
79
+ // but this requires a reference-count bump. OptionalTensorRef, on
80
+ // the other hand, can materialize a `const Tensor &` without
81
+ // touching the reference-count.
82
+ class TORCH_API TensorBase {
83
+ public:
84
+ struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };
85
+
86
+ protected:
87
+ // Create a Tensor with a +0 reference count. Special care must be
88
+ // taken to avoid decrementing this reference count at destruction
89
+ // time. Intended to support MaybeOwnedTraits<Tensor>.
90
+ explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs)
91
+ : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>(rhs.impl_.get(), c10::raw::DontIncreaseRefcount{})) {}
92
+ friend MaybeOwnedTraits<TensorBase>;
93
+
94
+ public:
95
+ TensorBase() = default;
96
+ // This constructor should not be used by end users and is an implementation
97
+ // detail invoked by autogenerated code.
98
+ explicit TensorBase(
99
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
100
+ : impl_(std::move(tensor_impl)) {
101
+ if (impl_.get() == nullptr) {
102
+ throw std::runtime_error("TensorImpl with nullptr is not supported");
103
+ }
104
+ }
105
+ TensorBase(const TensorBase&) = default;
106
+ TensorBase(TensorBase&&) noexcept = default;
107
+ ~TensorBase() noexcept = default;
108
+
109
+ public:
110
+ // Creates a new wrapper from TensorImpl. Intentionally a free method because
111
+ // it should be used with care. Checks necessary invariants
112
+ static TensorBase wrap_tensor_impl(
113
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
114
+ TensorBase r(std::move(tensor_impl));
115
+ r.enforce_invariants();
116
+ return r;
117
+ }
118
+
119
+ int64_t dim() const {
120
+ return impl_->dim();
121
+ }
122
+ int64_t storage_offset() const {
123
+ return impl_->storage_offset();
124
+ }
125
+
126
+ TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
127
+ if (is_contiguous(memory_format)) {
128
+ return *this;
129
+ } else {
130
+ return __dispatch_contiguous(memory_format);
131
+ }
132
+ }
133
+
134
+ /// Should be used if *this can reasonably be expected to be contiguous and
135
+ /// performance is important.
136
+ /// Compared to contiguous, it saves a reference count
137
+ /// increment/decrement if *this is already contiguous, at the cost
138
+ /// in all cases of an extra pointer of stack usage, an extra branch
139
+ /// to access, and an extra branch at destruction time.
140
+ c10::MaybeOwned<TensorBase> expect_contiguous(
141
+ MemoryFormat memory_format=MemoryFormat::Contiguous) const &;
142
+
143
+ // Use .contiguous() instead. Trying to borrow from a prvalue
144
+ // will only lead to trouble and dangling references.
145
+ c10::MaybeOwned<TensorBase> expect_contiguous(
146
+ MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
147
+
148
+ const TensorBase& fill_(const c10::Scalar& scalar) const;
149
+ const TensorBase& zero_() const;
150
+
151
+ TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, std::optional<at::MemoryFormat> memory_format=std::nullopt) const;
152
+
153
+ bool is_complex() const {
154
+ return at::isComplexType(this->scalar_type());
155
+ }
156
+
157
+ bool is_floating_point() const {
158
+ return at::isFloatingType(this->scalar_type());
159
+ }
160
+
161
+ bool is_signed() const {
162
+ return at::isSignedType(this->scalar_type());
163
+ }
164
+
165
+ c10::SymInt sym_size(int64_t dim) const {
166
+ return impl_->sym_size(dim);
167
+ }
168
+
169
+ c10::SymInt sym_stride(int64_t dim) const {
170
+ const auto sizes = this->sym_strides();
171
+ const auto ndim = static_cast<int64_t>(sizes.size());
172
+ // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
173
+ return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
174
+
175
+ }
176
+
177
+ int64_t size(int64_t dim) const {
178
+ return impl_->size(dim);
179
+ }
180
+
181
+ int64_t stride(int64_t dim) const {
182
+ const auto strides = this->strides();
183
+ const auto ndim = static_cast<int64_t>(strides.size());
184
+ // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
185
+ return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
186
+ }
187
+
188
+ TensorImpl * unsafeGetTensorImpl() const {
189
+ return impl_.get();
190
+ }
191
+ TensorImpl * unsafeReleaseTensorImpl() {
192
+ return impl_.release();
193
+ }
194
+ const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr() const {
195
+ return impl_;
196
+ }
197
+
198
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> unsafeReleaseIntrusivePtr() {
199
+ return std::move(impl_);
200
+ }
201
+
202
+ bool defined() const {
203
+ return impl_;
204
+ }
205
+
206
+ void reset() {
207
+ impl_.reset();
208
+ }
209
+
210
+ #if defined (_MSC_VER)
211
+ TensorBase& operator=(const TensorBase& x) & {
212
+ impl_ = x.impl_;
213
+ return *this;
214
+ };
215
+ TensorBase& operator=(TensorBase&& x) & noexcept {
216
+ impl_ = std::move(x.impl_);
217
+ return *this;
218
+ }
219
+ #else
220
+ TensorBase& operator=(const TensorBase& x) & = default;
221
+ TensorBase& operator=(TensorBase&& x) & noexcept = default;
222
+ #endif
223
+
224
+ // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here
225
+ TensorBase& operator=(const TensorBase&) && = delete;
226
+ TensorBase& operator=(TensorBase&&) && noexcept = delete;
227
+
228
+ bool is_same(const TensorBase& other) const noexcept {
229
+ return impl_ == other.impl_;
230
+ }
231
+ size_t use_count() const noexcept {
232
+ return impl_.use_count();
233
+ }
234
+ size_t weak_use_count() const noexcept {
235
+ return impl_.weak_use_count();
236
+ }
237
+
238
+ std::string toString() const;
239
+
240
+ IntArrayRef sizes() const {
241
+ return impl_->sizes();
242
+ }
243
+ c10::SymIntArrayRef sym_sizes() const {
244
+ return impl_->sym_sizes();
245
+ }
246
+ c10::SymIntArrayRef sym_strides() const {
247
+ return impl_->sym_strides();
248
+ }
249
+ IntArrayRef strides() const {
250
+ return impl_->strides();
251
+ }
252
+ // See impl::get_opt_names in ATen/NamedTensor.h for docs.
253
+ std::optional<DimnameList> opt_names() const {
254
+ return impl::get_opt_names(unsafeGetTensorImpl());
255
+ }
256
+ // See impl::get_names in ATen/NamedTensor.h for docs.
257
+ DimnameList names() const {
258
+ return impl::get_names(unsafeGetTensorImpl());
259
+ }
260
+ int64_t ndimension() const {
261
+ return dim();
262
+ }
263
+
264
+ bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
265
+ return impl_->is_contiguous(memory_format);
266
+ }
267
+
268
+ bool is_non_overlapping_and_dense() const {
269
+ return impl_->is_non_overlapping_and_dense();
270
+ }
271
+
272
+ at::MemoryFormat suggest_memory_format(
273
+ bool channels_last_strides_exact_match = false) const {
274
+ // Setting channels_last_strides_exact_match to true forces function to
275
+ // check 0,1 - sized dimension strides.
276
+ if (layout() == at::kStrided) {
277
+ if (impl_->is_strides_like_channels_last()) {
278
+ if (!channels_last_strides_exact_match ||
279
+ get_channels_last_strides_2d(sizes()) == strides()) {
280
+ return at::MemoryFormat::ChannelsLast;
281
+ }
282
+ }
283
+ else if (impl_->is_strides_like_channels_last_3d()) {
284
+ if (!channels_last_strides_exact_match ||
285
+ get_channels_last_strides_3d(sizes()) == strides()) {
286
+ return at::MemoryFormat::ChannelsLast3d;
287
+ }
288
+ }
289
+ }
290
+ return at::MemoryFormat::Contiguous;
291
+ }
292
+
293
+ // Total bytes consumed by the "view" of elements of the array. Does not
294
+ // include size of metadata. The number reported here does not necessarily
295
+ // correspond to the true physical memory consumed by a tensor; instead,
296
+ // it reports the memory the tensor would take *if* it were contiguous.
297
+ // Defined to be numel() * itemsize()
298
+ size_t nbytes() const {
299
+ TORCH_CHECK(layout () != at::kSparse,
300
+ "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
301
+ "tensors, add the nbytes of the indices and values. If you want the size of the " \
302
+ "equivalent dense tensor, multiply numel() by element_size()");
303
+ return impl_->numel() * impl_->itemsize();
304
+ }
305
+
306
+ c10::SymInt sym_nbytes() const {
307
+ TORCH_CHECK(layout () != at::kSparse,
308
+ "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
309
+ "tensors, add the nbytes of the indices and values. If you want the size of the " \
310
+ "equivalent dense tensor, multiply numel() by element_size()");
311
+ return impl_->sym_numel() * impl_->itemsize();
312
+ }
313
+
314
+ int64_t numel() const {
315
+ return impl_->numel();
316
+ }
317
+
318
+ c10::SymInt sym_numel() const {
319
+ return impl_->sym_numel();
320
+ }
321
+
322
+ c10::SymInt sym_storage_offset() const {
323
+ return impl_->sym_storage_offset();
324
+ }
325
+
326
+ // Length of one array element in bytes. This is the traditional
327
+ // Numpy naming.
328
+ size_t itemsize() const {
329
+ return impl_->itemsize();
330
+ }
331
+
332
+ // Same as itemsize(). This is the PyTorch naming.
333
+ int64_t element_size() const {
334
+ return static_cast<int64_t>(impl_->itemsize());
335
+ }
336
+
337
+ DispatchKeySet key_set() const {
338
+ return impl_->key_set();
339
+ }
340
+ ScalarType scalar_type() const {
341
+ return typeMetaToScalarType(impl_->dtype());
342
+ }
343
+ bool has_storage() const {
344
+ return defined() && impl_->has_storage();
345
+ }
346
+ const Storage& storage() const {
347
+ return impl_->storage();
348
+ }
349
+ bool is_alias_of(const at::TensorBase& other) const{
350
+ return impl_->storage().is_alias_of(other.storage());
351
+ }
352
+
353
+ // Move the storage backend to shm based
354
+ // to enable memory sharing across processes.
355
+ //
356
+ // NB1: the ideal behavior of this API still requires further discussion
357
+ // but for now we are inclined to keep it consistent with existing THP behavior
358
+ // https://github.com/pytorch/pytorch/blob/4dca9bde0552afc67b5b74f4a0696fe6055709c4/torch/storage.py#L196-L212
359
+ // so we don't assert on anything here and rely on caller knowing
360
+ // what it's doing.
361
+ //
362
+ // NB2: this currently provides Linux fd based shm support only
363
+ // to simplify the storage lifetime management logic in ATen
364
+ // and similarly for now we are not adding support for file system based
365
+ // shm support like in THP due to additional GC manager support needed
366
+ // to prevent leaks.
367
+ // As such, calling this from non supported systems (e.g. Windows) would fail.
368
+ void share_memory_() {
369
+ at::share_memory_(*this);
370
+ }
371
+
372
+ inline bool _is_zerotensor() const {
373
+ return impl_->_is_zerotensor();
374
+ }
375
+
376
+ inline void _set_zero(bool zero) const {
377
+ impl_->_set_zero(zero);
378
+ }
379
+
380
+ inline bool is_conj() const {
381
+ return impl_->is_conj();
382
+ }
383
+
384
+ // sets the conjugate bit of a tensor.
385
+ // NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are sure
386
+ // that's what you want. Changing this might lead to incorrect behavior since conjugation is
387
+ // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
388
+ inline void _set_conj(bool conjugate) const {
389
+ impl_->_set_conj(conjugate);
390
+ }
391
+
392
+ inline bool is_neg() const {
393
+ return impl_->is_neg();
394
+ }
395
+
396
+ // sets the negative bit of a tensor.
397
+ // NOTE: Negative bit is supposed to be a read-only field. Only change this, if you are sure
398
+ // that's what you want. Changing this might lead to incorrect behavior since we rely on this
399
+ // bit to determine if a negation needs to be materialized.
400
+ inline void _set_neg(bool negative) const {
401
+ impl_->_set_neg(negative);
402
+ }
403
+
404
+ /// Returns a `Tensor`'s layout.
405
+ Layout layout() const {
406
+ return impl_->layout();
407
+ }
408
+
409
+ /// Returns a `Tensor`'s dtype (`TypeMeta`).
410
+ caffe2::TypeMeta dtype() const {
411
+ return impl_->dtype();
412
+ }
413
+
414
+ /// Returns a `Tensor`'s device.
415
+ inline Device device() const {
416
+ return impl_->device();
417
+ }
418
+
419
+ /// Returns a `Tensor`'s device index.
420
+ DeviceIndex get_device() const {
421
+ // NB: this is not a native function to avoid dispatching overhead.
422
+ return impl_->get_device();
423
+ }
424
+
425
+ /// Returns if a `Tensor` has CPU backend.
426
+ bool is_cpu() const {
427
+ // NB: this is not a native function to avoid dispatching overhead.
428
+ return impl_->is_cpu();
429
+ }
430
+
431
+ /// Returns if a `Tensor` has CUDA backend.
432
+ bool is_cuda() const {
433
+ // NB: this is not a native function to avoid dispatching overhead.
434
+ return impl_->is_cuda();
435
+ }
436
+
437
+ /// Returns if a `Tensor` has IPU backend.
438
+ bool is_ipu() const {
439
+ // NB: this is not a native function to avoid dispatching overhead.
440
+ return impl_->is_ipu();
441
+ }
442
+
443
+ /// Returns if a `Tensor` has XPU backend.
444
+ bool is_xpu() const {
445
+ // NB: this is not a native function to avoid dispatching overhead.
446
+ return impl_->is_xpu();
447
+ }
448
+
449
+ /// Returns if a `Tensor` has XLA backend.
450
+ bool is_xla() const {
451
+ return impl_->is_xla();
452
+ }
453
+
454
+ /// Returns if a `Tensor` has MTIA backend.
455
+ bool is_mtia() const {
456
+ return impl_->is_mtia();
457
+ }
458
+
459
+ /// Returns if a `Tensor` has HPU backend.
460
+ bool is_hpu() const {
461
+ return impl_->is_hpu();
462
+ }
463
+
464
+ /// Returns if a `Tensor` has Lazy backend.
465
+ bool is_lazy() const {
466
+ return impl_->is_lazy();
467
+ }
468
+
469
+ /// Returns if a `Tensor` has HIP backend.
470
+ bool is_hip() const {
471
+ // NB: this is not a native function to avoid dispatching overhead.
472
+ return impl_->is_hip();
473
+ }
474
+
475
+ /// Returns if a `Tensor` has VE backend.
476
+ bool is_ve() const {
477
+ // NB: this is not a native function to avoid dispatching overhead.
478
+ return impl_->is_ve();
479
+ }
480
+
481
+ /// Returns if a `Tensor` has PrivateUse1 backend.
482
+ bool is_privateuseone() const {
483
+ // NB: this is not a native function to avoid dispatching overhead.
484
+ return impl_->is_privateuseone();
485
+ }
486
+
487
+ /// Returns if a `Tensor` has sparse backend.
488
+ bool is_sparse() const {
489
+ // NB: this is not a native function to avoid dispatching overhead.
490
+ return impl_->is_sparse();
491
+ }
492
+
493
+ /// Returns if a `Tensor` has a sparse CSR backend.
494
+ bool is_sparse_csr() const {
495
+ // NB: this is not a native function to avoid dispatching overhead.
496
+ return impl_->is_sparse_csr();
497
+ }
498
+
499
+ /// Returns if a `Tensor` is mkldnn tensor.
500
+ bool is_mkldnn() const {
501
+ // NB: this is not a native function to avoid dispatching overhead.
502
+ return impl_->is_mkldnn();
503
+ }
504
+
505
+ /// Returns if a `Tensor` is mps tensor.
506
+ bool is_mps() const {
507
+ // NB: this is not a native function to avoid dispatching overhead.
508
+ return impl_->is_mps();
509
+ }
510
+
511
+ /// Returns if a `Tensor` is maia tensor.
512
+ bool is_maia() const {
513
+ // NB: this is not a native function to avoid dispatching overhead.
514
+ return impl_->is_maia();
515
+ }
516
+
517
+ /// Returns if a `Tensor` is vulkan tensor.
518
+ bool is_vulkan() const {
519
+ // NB: this is not a native function to avoid dispatching overhead.
520
+ return impl_->is_vulkan();
521
+ }
522
+
523
+ /// Returns if a `Tensor` is metal tensor.
524
+ bool is_metal() const {
525
+ // NB: this is not a native function to avoid dispatching overhead.
526
+ return impl_->is_metal();
527
+ }
528
+
529
+ /// Returns if a `Tensor` has quantized backend.
530
+ bool is_quantized() const {
531
+ // NB: this is not a native function to avoid dispatching overhead.
532
+ return impl_->is_quantized();
533
+ }
534
+
535
+ /// Returns if a `Tensor` is a meta tensor. Meta tensors can
536
+ /// also have other designations.
537
+ bool is_meta() const {
538
+ return impl_->is_meta();
539
+ }
540
+
541
+ /// Returns if a `Tensor` is an inference tensor.
542
+ bool is_inference() const {
543
+ return impl_->is_inference();
544
+ }
545
+
546
+ // Returns if a `Tensor` is a NestedTensor.
547
+ bool is_nested() const {
548
+ return impl_->is_nested();
549
+ }
550
+
551
+ /// If a tensor is a quantized tensor, returns its quantizer
552
+ /// TODO: it's not in native_functions.yaml yet as it's not exposed to python
553
+ QuantizerPtr quantizer() const;
554
+
555
+ /// Returns if a `Tensor` has any dimension names
556
+ bool has_names() const {
557
+ // If a user is using unnamed tensors, then we can short-circuit right here.
558
+ // Otherwise, impl::has_names attempts to retrieve names.
559
+ if (!impl_->has_named_tensor_meta()) {
560
+ return false;
561
+ }
562
+ return impl::has_names(unsafeGetTensorImpl());
563
+ }
564
+
565
+ /// Returns a `Tensor`'s dimension names data structure
566
+ const NamedTensorMeta* get_named_tensor_meta() const {
567
+ return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
568
+ }
569
+
570
+ NamedTensorMeta* get_named_tensor_meta() {
571
+ return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
572
+ }
573
+
574
+ /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
575
+ /// TensorOptions.h.
576
+ TensorOptions options() const {
577
+ return TensorOptions().dtype(dtype())
578
+ .device(device())
579
+ .layout(layout());
580
+ }
581
+
582
+ const void* const_data_ptr() const {
583
+ return this->unsafeGetTensorImpl()->data();
584
+ }
585
+
586
+ void* mutable_data_ptr() const {
587
+ return this->unsafeGetTensorImpl()->mutable_data();
588
+ }
589
+
590
+ // TODO(#97856) Make this return a const pointer. This currently
591
+ // returns a non-const pointer because of the large
592
+ // number of clients that we still want to audit before
593
+ // migrating to mutable_data_ptr().
594
+ void* data_ptr() const {
595
+ return mutable_data_ptr();
596
+ }
597
+
598
+ template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
599
+ const T* const_data_ptr() const;
600
+
601
+ template <typename T, std::enable_if_t<std::is_const_v<T>, int> = 0>
602
+ const std::remove_const_t<T>* const_data_ptr() const;
603
+
604
+ template <typename T>
605
+ T* mutable_data_ptr() const;
606
+
607
+ // Legacy interface during the migration to indicate that a callsite
608
+ // has not been audited for mutability.
609
+ //
610
+ // Do not add new uses of this, use const_data_ptr() if possible,
611
+ // mutable_data_ptr() otherwise.
612
+ //
613
+ // TODO(#97856) Make this return a const pointer. This is currently
614
+ // const because of the vast number of clients that
615
+ // rely on this.
616
+ template <typename T>
617
+ T* data_ptr() const;
618
+
619
+ // Purposely not defined here to avoid inlining
620
+ void print() const;
621
+
622
+ // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and
623
+ // dimension.
624
+ template<typename T, size_t N>
625
+ TensorAccessor<T,N> accessor() const& {
626
+ static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
627
+ TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
628
+ T* ptr = nullptr;
629
+ if constexpr (std::is_const_v<T>) {
630
+ ptr = const_data_ptr<T>();
631
+ } else {
632
+ ptr = mutable_data_ptr<T>();
633
+ }
634
+ return TensorAccessor<T,N>(ptr,sizes().data(),strides().data());
635
+ }
636
+ template<typename T, size_t N>
637
+ TensorAccessor<T,N> accessor() && = delete;
638
+
639
+ // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
640
+ // dimension. You can optionally specify RestrictPtrTraits as a template parameter to
641
+ // cast the data pointer to a __restrict__ pointer.
642
+ // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
643
+ // as an argument.
644
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
645
+ GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& {
646
+ static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
647
+ TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
648
+ T* ptr = nullptr;
649
+ if constexpr (std::is_const_v<T>) {
650
+ ptr = const_data_ptr<T>();
651
+ } else {
652
+ ptr = mutable_data_ptr<T>();
653
+ }
654
+ return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(ptr),sizes().data(),strides().data());
655
+ }
656
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
657
+ GenericPackedTensorAccessor<T,N> generic_packed_accessor() && = delete;
658
+
659
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
660
+ PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& {
661
+ TORCH_CHECK(
662
+ impl_->numel() <=
663
+ static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
664
+ "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
665
+ return generic_packed_accessor<T,N,PtrTraits,int32_t>();
666
+ }
667
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
668
+ PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete;
669
+
670
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
671
+ PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& {
672
+ return generic_packed_accessor<T,N,PtrTraits,int64_t>();
673
+ }
674
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
675
+ PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete;
676
+
677
+ // ~~~~~ Autograd API ~~~~~
678
+
679
+ /// \fn bool is_leaf() const;
680
+ ///
681
+ /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention.
682
+ ///
683
+ /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were
684
+ /// created by the user. This means that they are not the result of an operation and so
685
+ /// `grad_fn()` is `nullptr`.
686
+ ///
687
+ /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`.
688
+ /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`.
689
+ ///
690
+ /// Example:
691
+ /// @code
692
+ /// auto a = torch::rand(10, torch::requires_grad());
693
+ /// std::cout << a.is_leaf() << std::endl; // prints `true`
694
+ ///
695
+ /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA);
696
+ /// std::cout << b.is_leaf() << std::endl; // prints `false`
697
+ /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor
698
+ ///
699
+ /// auto c = torch::rand(10, torch::requires_grad()) + 2;
700
+ /// std::cout << c.is_leaf() << std::endl; // prints `false`
701
+ /// // c was created by the addition operation
702
+ ///
703
+ /// auto d = torch::rand(10).cuda();
704
+ /// std::cout << d.is_leaf() << std::endl; // prints `true`
705
+ /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
706
+ ///
707
+ /// auto e = torch::rand(10).cuda().requires_grad_();
708
+ /// std::cout << e.is_leaf() << std::endl; // prints `true`
709
+ /// // e requires gradients and has no operations creating it
710
+ ///
711
+ /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true));
712
+ /// std::cout << f.is_leaf() << std::endl; // prints `true`
713
+ /// // f requires grad, has no operation creating it
714
+ /// @endcode
715
+
716
+ /// \fn void backward(const Tensor & gradient={}, std::optional<bool> retain_graph=std::nullopt, bool create_graph=false, std::optional<TensorList> inputs=std::nullopt) const;
717
+ ///
718
+ /// Computes the gradient of current tensor with respect to graph leaves.
719
+ ///
720
+ /// The graph is differentiated using the chain rule. If the tensor is
721
+ /// non-scalar (i.e. its data has more than one element) and requires
722
+ /// gradient, the function additionally requires specifying ``gradient``.
723
+ /// It should be a tensor of matching type and location, that contains
724
+ /// the gradient of the differentiated function w.r.t. this Tensor.
725
+ ///
726
+ /// This function accumulates gradients in the leaves - you might need to
727
+ /// zero them before calling it.
728
+ ///
729
+ /// \param gradient Gradient w.r.t. the
730
+ /// tensor. If it is a tensor, it will be automatically converted
731
+ /// to a Tensor that does not require grad unless ``create_graph`` is True.
732
+ /// None values can be specified for scalar Tensors or ones that
733
+ /// don't require grad. If a None value would be acceptable then
734
+ /// this argument is optional.
735
+ /// \param retain_graph If ``false``, the graph used to compute
736
+ /// the grads will be freed. Note that in nearly all cases setting
737
+ /// this option to True is not needed and often can be worked around
738
+ /// in a much more efficient way. Defaults to the value of
739
+ /// ``create_graph``.
740
+ /// \param create_graph If ``true``, graph of the derivative will
741
+ /// be constructed, allowing to compute higher order derivative
742
+ /// products. Defaults to ``false``.
743
+ /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
744
+ /// ``at::Tensor::grad``. All other Tensors will be ignored. If not
745
+ /// provided, the gradient is accumulated into all the leaf Tensors
746
+ /// that were used to compute the current tensor.
747
+ /// When inputs are provided and a given input is not a leaf,
748
+ /// the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
749
+ /// It is an implementation detail on which the user should not rely.
750
+ /// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
751
+
752
+ /// \fn Tensor detach() const;
753
+ ///
754
+ /// Returns a new Tensor, detached from the current graph.
755
+ /// The result will never require gradient.
756
+
757
+ /// \fn Tensor & detach_() const;
758
+ ///
759
+ /// Detaches the Tensor from the graph that created it, making it a leaf.
760
+ /// Views cannot be detached in-place.
761
+
762
+ /// \fn void retain_grad() const;
763
+ ///
764
+ /// Enables this Tensor to have its :attr:`grad` populated during
765
+ /// :func:`backward`. This is a no-op for leaf tensors.
766
+
767
+ /// \fn bool retains_grad() const;
768
+ ///
769
+ /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
770
+ /// populated during :func:`backward`, ``false`` otherwise.
771
+
772
+ const TensorBase& set_requires_grad(bool requires_grad) const {
773
+ impl_->set_requires_grad(requires_grad);
774
+ return *this;
775
+ }
776
+ bool requires_grad() const {
777
+ return impl_->requires_grad();
778
+ }
779
+
780
+ // The Forward AD API functions below are low level and are not to be used by end
781
+ // users who should use the API provided in torch/csrc/autograd.h
782
+
783
+ /// This function returns the forward gradient for this Tensor at the given level.
784
+ const Tensor& _fw_grad(uint64_t level) const {
785
+ return impl_->_fw_grad(level, *this);
786
+ }
787
+
788
+ /// This function can be used to set the value of the forward grad.
789
+ /// Note that the given new_grad might not be used directly if it has different
790
+ /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
791
+ /// new_grad content will be copied into a new Tensor
792
+ void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
793
+ impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
794
+ }
795
+
796
+ /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
797
+ /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
798
+ /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
799
+ ///
800
+ /// One notable difference with the legacy `.data()` function is that changes to the
801
+ /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
802
+ /// will not update the original `Variable`, due to the fact that this function
803
+ /// shallow-copies the `Variable`'s underlying TensorImpl.
804
+ at::TensorBase tensor_data() const;
805
+
806
+ /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
807
+ /// in Python, which create a new `Variable` that shares the same storage and
808
+ /// tensor metadata with the original `Variable`, but with a completely new
809
+ /// autograd history.
810
+ ///
811
+ /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
812
+ /// storage / storage_offset) of a variable created from `var.variable_data()`, those
813
+ /// changes will not update the original variable `var`. In `.variable_data()`, we set
814
+ /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
815
+ /// in order to prevent users from changing metadata of `var.variable_data()`
816
+ /// and expecting the original variable `var` to also be updated.
817
+ at::TensorBase variable_data() const;
818
+
819
+ // Gradient Node and Edges
820
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
821
+
822
+ /// Gets the gradient function of the `Variable`. If this is a leaf variable,
823
+ /// the pointer returned will be null.
824
+ ///
825
+ /// For View Variables:
826
+ /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
827
+ /// re-create the grad_fn to express the up-to-date view relationship between
828
+ /// this and the base Variable.
829
+ const std::shared_ptr<torch::autograd::Node>& grad_fn() const;
830
+
831
+ // Hooks
832
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
833
+
834
+ template <typename T>
835
+ using hook_return_void_t = std::enable_if_t<std::is_void_v<typename std::invoke_result_t<T&, TensorBase>>, unsigned>;
836
+ template <typename T>
837
+ using hook_return_var_t = std::enable_if_t<std::is_same_v<typename std::invoke_result_t<T&, TensorBase>, TensorBase>, unsigned>;
838
+
839
+ /// Registers a backward hook.
840
+ ///
841
+ /// The hook will be called every time a gradient with respect to the Tensor is computed.
842
+ /// The hook should have one of the following signatures:
843
+ /// ```
844
+ /// hook(TensorBase grad) -> TensorBase
845
+ /// ```
846
+ /// ```
847
+ /// hook(TensorBase grad) -> void
848
+ /// ```
849
+ /// The hook should not modify its argument, but it can optionally return a new gradient
850
+ /// which will be used in place of `grad`.
851
+ ///
852
+ /// This function returns the index of the hook in the list which can be used to remove hook.
853
+ ///
854
+ /// Example:
855
+ /// @code
856
+ /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
857
+ /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
858
+ /// v.backward(torch::tensor({1., 2., 3.}));
859
+ /// // This prints:
860
+ /// // ```
861
+ /// // 2
862
+ /// // 4
863
+ /// // 6
864
+ /// // [ CPUFloatType{3} ]
865
+ /// // ```
866
+ /// std::cout << v.grad() << std::endl;
867
+ /// v.remove_hook(h); // removes the hook
868
+ /// @endcode
869
+ template <typename T>
870
+ hook_return_void_t<T> register_hook(T&& hook) const;
871
+ template <typename T>
872
+ hook_return_var_t<T> register_hook(T&& hook) const;
873
+
874
+ protected:
875
+ unsigned _register_hook(std::function<TensorBase(const TensorBase&)> hook) const;
876
+
877
+ public:
878
+
879
+ /// Remove hook at given position
880
+ void remove_hook(unsigned pos) const;
881
+
882
+ // Variable methods
883
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
884
+
885
+ bool is_leaf() const;
886
+
887
+ int64_t output_nr() const;
888
+
889
+ void set_data(const TensorBase & new_data) const;
890
+
891
+ TensorBase data() const;
892
+
893
+ int64_t _version() const;
894
+
895
+ void retain_grad() const;
896
+
897
+ bool retains_grad() const;
898
+
899
+ const TensorBase& requires_grad_(bool _requires_grad=true) const;
900
+
901
+ // View Variables
902
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
903
+
904
+ /// Returns true if this `Variable` is a view of another `Variable`.
905
+ bool is_view() const;
906
+
907
+ /// Returns the `Variable` that this `Variable` is a view of. If this
908
+ /// `Variable` is not a view, throw a `std::runtime_error`.
909
+ const TensorBase& _base() const;
910
+
911
+ // Miscellaneous
912
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
913
+
914
+ const std::string& name() const;
915
+
916
+ protected:
917
+ void enforce_invariants();
918
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
919
+
920
+ private:
921
+ TensorBase __dispatch_contiguous(c10::MemoryFormat) const;
922
+ };
923
+
924
+ inline DeviceIndex get_device(const TensorBase& self) {
925
+ return self.get_device();
926
+ }
927
+
928
+ template <typename T>
929
+ auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t<T> {
930
+ // Return the grad argument in case of a hook with void return type to have an
931
+ // std::function with Tensor return type
932
+ static_assert(std::is_same_v<decltype(hook(TensorBase())), void>,
933
+ "Expected hook to return void");
934
+ return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad) {
935
+ fn(grad);
936
+ return TensorBase();
937
+ });
938
+ }
939
+
940
+ template <typename T>
941
+ auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> {
942
+ return _register_hook(std::forward<T>(hook));
943
+ }
944
+
945
+ namespace detail {
946
+ // Helper creator for Tensor class which doesn't require the users to pass
947
+ // in an intrusive_ptr instead it just converts the argument passed to
948
+ // requested intrusive_ptr type.
949
+ template <typename T, typename... Args>
950
+ TensorBase make_tensor_base(Args&&... args) {
951
+ return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...));
952
+ }
953
+
954
+ } // namespace detail
955
+
956
+ inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
957
+ return legacyExtractDispatchKey(t.key_set());
958
+ }
959
+
960
+ } // namespace at
961
+
962
+ namespace c10 {
963
+ template <>
964
+ struct MaybeOwnedTraits<at::TensorBase> {
965
+ using owned_type = at::TensorBase;
966
+ using borrow_type = at::TensorBase;
967
+
968
+ static borrow_type createBorrow(const owned_type& from) {
969
+ // NOTE: this can be implemented without the special
970
+ // unsafe_borrow_t Tensor constructor as
971
+ //
972
+ // return borrow_type(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl()));
973
+ //
974
+ // but that hurts inlining due to the nullptr check in the
975
+ // Tensor(c10::intrusive_ptr<...>) constructor. We already know
976
+ // that from.impl_ isn't null because from is a valid Tensor, so
977
+ // we needn't do the check again. (using __builtin_assume can
978
+ // avoid this, but wouldn't be portable to MSVC.)
979
+ return borrow_type(borrow_type::unsafe_borrow_t{}, from);
980
+ }
981
+
982
+ static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
983
+ lhs.unsafeReleaseTensorImpl();
984
+ // See above note: this can be implemented with public API
985
+ // similarly to createBorrow(), but that would hurt inlining.
986
+ lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
987
+ }
988
+
989
+ static void destroyBorrow(borrow_type& toDestroy) {
990
+ toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
991
+ }
992
+
993
+ static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
994
+ return borrow;
995
+ }
996
+
997
+ static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
998
+ return &borrow;
999
+ }
1000
+
1001
+ static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
1002
+ return true;
1003
+ }
1004
+ };
1005
+
1006
+ template <>
1007
+ struct ExclusivelyOwnedTraits<at::TensorBase> : public c10::ExclusivelyOwnedTensorTraits<at::TensorBase> {};
1008
+ } // namespace c10
1009
+
1010
+ namespace at {
1011
+
1012
+ inline c10::MaybeOwned<TensorBase> borrow_from_optional_tensor(
1013
+ const std::optional<TensorBase>& opt) {
1014
+ return opt.has_value()
1015
+ ? c10::MaybeOwned<TensorBase>::borrowed(*opt)
1016
+ : c10::MaybeOwned<TensorBase>::owned(std::in_place);
1017
+ }
1018
+
1019
+ inline c10::MaybeOwned<TensorBase> TensorBase::expect_contiguous(MemoryFormat memory_format) const & {
1020
+ if (is_contiguous(memory_format)) {
1021
+ return c10::MaybeOwned<TensorBase>::borrowed(*this);
1022
+ } else {
1023
+ return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format));
1024
+ }
1025
+ }
1026
+
1027
+ namespace symint {
1028
+
1029
+ template <typename T>
1030
+ using enable_if_symint = std::enable_if_t<std::is_same_v<T, c10::SymInt>>;
1031
+ template <typename T>
1032
+ using enable_if_int = std::enable_if_t<std::is_same_v<T, int64_t>>;
1033
+
1034
+ template <typename T, typename = enable_if_symint<T>>
1035
+ c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); }
1036
+ template <typename T, typename = enable_if_int<T>>
1037
+ IntArrayRef sizes(const TensorBase& t) { return t.sizes(); }
1038
+
1039
+ template <typename T, typename = enable_if_symint<T>>
1040
+ c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); }
1041
+ template <typename T, typename = enable_if_int<T>>
1042
+ int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); }
1043
+
1044
+ template <typename T, typename = enable_if_symint<T>>
1045
+ c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); }
1046
+ template <typename T, typename = enable_if_int<T>>
1047
+ IntArrayRef strides(const TensorBase& t) { return t.strides(); }
1048
+
1049
+ template <typename T, typename = enable_if_symint<T>>
1050
+ c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); }
1051
+ template <typename T, typename = enable_if_int<T>>
1052
+ int64_t numel(const TensorBase& t) { return t.numel(); }
1053
+
1054
+ } // namespace symint
1055
+
1056
+ } // namespace at
.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBody.h ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.12/site-packages/torch/include/ATen/core/TorchDispatchUtils.h ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/dispatch/Dispatcher.h>
4
+ #include <c10/core/impl/TorchDispatchModeTLS.h>
5
+ #include <c10/util/ArrayRef.h>
6
+ #include <torch/library.h>
7
+ #include <optional>
8
+
9
+ namespace at::impl {
10
+
11
+ TORCH_API bool tensor_has_dispatch(const at::Tensor& t);
12
+ TORCH_API bool tensorlist_has_dispatch(at::ITensorListRef li);
13
+ TORCH_API bool tensorlist_has_dispatch(
14
+ const c10::List<std::optional<at::Tensor>>& li);
15
+ using c10::impl::dispatch_mode_enabled;
16
+
17
+ } // namespace at::impl
.venv/lib/python3.12/site-packages/torch/include/ATen/core/TransformationHelper.h ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/NumericUtils.h>
2
+ #include <c10/macros/Macros.h>
3
+ #include <c10/util/Half.h>
4
+ #include <c10/util/BFloat16.h>
5
+ #include <c10/util/MathConstants.h>
6
+ #include <cmath>
7
+ #include <cstdint>
8
+ #include <cassert>
9
+ #include <limits>
10
+ #include <type_traits>
11
+
12
+ namespace at {
13
+
14
+ // Using DistAccumType in accumulate types for distributions.
15
+ // Note: Ideally we'd be using ATen/AccumulateType.h but looks
16
+ // like there is some inconsistency in how accumulate types
17
+ // are mapped currently, e.g. for the cpu side, float is mapped
18
+ // to double.
19
+ template <typename T>
20
+ struct DistAccumType { };
21
+
22
+ #if defined(__CUDACC__) || defined(__HIPCC__)
23
+ template <> struct DistAccumType<half> { using type = float; };
24
+ #endif
25
+ template <> struct DistAccumType<BFloat16> { using type = float; };
26
+ template <> struct DistAccumType<Half> { using type = float; };
27
+ template <> struct DistAccumType<float> { using type = float; };
28
+ template <> struct DistAccumType<double> { using type = double; };
29
+
30
+ template <typename T>
31
+ using dist_acctype = typename DistAccumType<T>::type;
32
+
33
+ namespace transformation {
34
+
35
+ /**
36
+ * A transformation function for `torch.Tensor.random_()`, when both `from` and `to` are specified.
37
+ * `range` is `to - from`
38
+ * `base` is `from`
39
+ */
40
+ template <typename T, typename V>
41
+ C10_HOST_DEVICE inline T uniform_int_from_to(V val, uint64_t range, int64_t base) {
42
+ return static_cast<T>(static_cast<int64_t>((val % range) + base));
43
+ }
44
+
45
+ /**
46
+ * A transformation function for `torch.Tensor.random_()`, when `from=min_value(int64_t)` and `to=None`
47
+ */
48
+ template <typename T, typename V>
49
+ C10_HOST_DEVICE inline T uniform_int_full_range(V val) {
50
+ return static_cast<T>(static_cast<int64_t>(val));
51
+ }
52
+
53
+ /**
54
+ * A transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`.
55
+ * In order to prevent compiler warnings reported in GitHub issue 46391, T can't be float or double
56
+ * in this overloaded version
57
+ */
58
+ template <typename T, typename V>
59
+ C10_HOST_DEVICE inline std::enable_if_t<!(std::is_floating_point_v<T>), T>uniform_int(V val) {
60
+ if constexpr (std::is_same_v<T, bool>) {
61
+ return static_cast<bool>(val & 1);
62
+ } else if constexpr (std::is_same_v<T, int64_t>) {
63
+ return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
64
+ } else if constexpr (std::is_same_v<T, at::Half> || std::is_same_v<T, at::BFloat16>) {
65
+ return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
66
+ } else if constexpr (std::is_integral_v<T>) {
67
+ return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
68
+ } else {
69
+ assert(false);
70
+ return 0;
71
+ }
72
+ }
73
+
74
+ /**
75
+ * An overloaded transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`,
76
+ * added to fix compiler warnings reported in GitHub issue 46391. T is either float or double in this version.
77
+ */
78
+ template<typename T, typename V>
79
+ C10_HOST_DEVICE inline std::enable_if_t<std::is_floating_point_v<T>, T>uniform_int(V val) {
80
+ return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
81
+ }
82
+
83
+ template <typename T, typename V>
84
+ C10_HOST_DEVICE inline dist_acctype<T> uniform_real(V val, T from, T to) {
85
+ constexpr auto MASK = static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);
86
+ constexpr auto DIVISOR = static_cast<dist_acctype<T>>(1) / (static_cast<uint64_t>(1) << std::numeric_limits<T>::digits);
87
+ dist_acctype<T> x = (val & MASK) * DIVISOR;
88
+ return (x * (to - from) + from);
89
+ }
90
+
91
+ /**
92
+ * Transforms normally distributed `val` with mean 0.0 and standard deviation 1.0 to
93
+ * normally distributed with `mean` and standard deviation `std`.
94
+ */
95
+ template <typename T>
96
+ C10_HOST_DEVICE inline T normal(T val, T mean, T std) {
97
+ return val * std + mean;
98
+ }
99
+
100
+ /**
101
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
102
+ * Cauchy distribution with location parameter `median` and scale parameter `sigma`.
103
+ */
104
+ template <typename T>
105
+ C10_HOST_DEVICE inline T cauchy(T val, T median, T sigma) {
106
+ // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
107
+ // __tanf overflows and returns `inf/-inf` when (val > 1 - eps) or (val < 0 + eps),
108
+ // thus we clip those values.
109
+ constexpr T eps = std::numeric_limits<T>::epsilon();
110
+ constexpr T one_minus_eps = 1 - eps;
111
+ constexpr T zero_plus_eps = 0 + eps;
112
+ val = (val > one_minus_eps ? one_minus_eps : val);
113
+ val = (val < zero_plus_eps ? zero_plus_eps : val);
114
+ return median + sigma * at::tan(c10::pi<T> * (val - static_cast<T>(0.5)));
115
+ }
116
+
117
+ template <>
118
+ C10_HOST_DEVICE inline double cauchy(double val, double median, double sigma) {
119
+ // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
120
+ return median + sigma * at::tan(c10::pi<double> * (val - static_cast<double>(0.5)));
121
+ }
122
+
123
+ /**
124
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
125
+ * exponentially distributed with `lambda` parameter of the distribution.
126
+ */
127
+ template <typename T>
128
+ C10_HOST_DEVICE inline T exponential(T val, T lambda) {
129
+ // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
130
+ // Different implementations for CUDA and CPU to preserve original logic
131
+ // TODO: must be investigated and unified!!!
132
+ // https://github.com/pytorch/pytorch/issues/38662
133
+ #if defined(__CUDACC__) || defined(__HIPCC__)
134
+ // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
135
+ // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
136
+ // we need log to be not 0, and not underflow when converted to half
137
+ // fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args
138
+ auto log = val >= static_cast<T>(1.) - std::numeric_limits<T>::epsilon() / 2
139
+ ? -std::numeric_limits<T>::epsilon() / 2
140
+ : at::log(val);
141
+ return static_cast<T>(-1.0) / lambda * log;
142
+ #else
143
+ return static_cast<T>(-1.0) / lambda * at::log1p(-val);
144
+ #endif
145
+ }
146
+
147
+ /**
148
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
149
+ * geometrically distributed with success probability `p`.
150
+ */
151
+ template <typename T>
152
+ C10_HOST_DEVICE inline T geometric(T val, T p) {
153
+ // https://en.wikipedia.org/wiki/Geometric_distribution#Related_distributions
154
+ return static_cast<T>(::ceil(at::log(val) / at::log1p(-p)));
155
+ }
156
+
157
+ /**
158
+ * Transforms normally distributed `val` to log-normally distributed.
159
+ */
160
+ template <typename T>
161
+ C10_HOST_DEVICE inline T log_normal(T val) {
162
+ // https://en.wikipedia.org/wiki/Log-normal_distribution#Mode,_median,_quantiles
163
+ return at::exp(val);
164
+ }
165
+
166
+ /**
167
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
168
+ * bernoulli distributed with success probability `p`.
169
+ */
170
+ template <typename T>
171
+ C10_HOST_DEVICE inline T bernoulli(T val, T p) {
172
+ return val < p;
173
+ }
174
+
175
+ }} // namespace at::transformation
.venv/lib/python3.12/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h ADDED
@@ -0,0 +1 @@
 
 
1
+ #include <c10/core/UndefinedTensorImpl.h>
.venv/lib/python3.12/site-packages/torch/include/ATen/core/UnsafeFromTH.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+
4
+ namespace at {
5
+
6
+ inline Tensor unsafeTensorFromTH(void * th_pointer, bool retain) {
7
+ auto tensor_impl = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(static_cast<TensorImpl*>(th_pointer));
8
+ if (retain && tensor_impl.get() != UndefinedTensorImpl::singleton()) {
9
+ c10::raw::intrusive_ptr::incref(tensor_impl.get());
10
+ }
11
+ return Tensor(std::move(tensor_impl));
12
+ }
13
+
14
+ inline Storage unsafeStorageFromTH(void * th_pointer, bool retain) {
15
+ if (retain && th_pointer) {
16
+ c10::raw::intrusive_ptr::incref(static_cast<StorageImpl*>(th_pointer));
17
+ }
18
+ return Storage(c10::intrusive_ptr<StorageImpl>::reclaim(static_cast<StorageImpl*>(th_pointer)));
19
+ }
20
+
21
+ }
.venv/lib/python3.12/site-packages/torch/include/ATen/core/VariableHooksInterface.h ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <c10/macros/Export.h>
5
+
6
+ // A little explanation about why this file exists at all. We have
7
+ // a few methods on Tensor class which require reified access to
8
+ // AutogradMeta. In open source, this isn't a big deal: we just access
9
+ // torch/csrc/autograd/variable.h from aten/src/ATen/core/Tensor.cpp and
10
+ // we can put the definitions inline. This is because everything gets balled
11
+ // into a single dynamic library in the end.
12
+ //
13
+ // However, inside our Facebook internal version of our build system, we
14
+ // have a split between aten and torch/csrc. So we cannot simply just
15
+ // cross this boundary. "Now wait," you might say, "Why don't we just
16
+ // merge the libraries inside Facebook". Well, the problem is that there
17
+ // are some downstream applications which are at binary size limit, and
18
+ // incorporating all of the extra code from libtorch would push them
19
+ // over (admarket/adreview/service:adreviewservice, see also
20
+ // https://github.com/pytorch/pytorch/pull/29299). So if you want to do that,
21
+ // we have to fix all of the services like this.
22
+ //
23
+ // I didn't want to block eliminating Tensor-Variable on this work, so I
24
+ // had to introduce another dynamic dispatch to get to the variable
25
+ // implementations (which live in torch/csrc/autograd/variable.cpp, FYI).
26
+ //
27
+ // I also considered using our existing dynamic dispatch mechanism, c10
28
+ // dispatcher, to do this. However, (1) some of the functions on Tensor
29
+ // have weird signatures that are not supported by autograd, and (2)
30
+ // see this bug https://github.com/pytorch/pytorch/issues/30102
31
+
32
+ namespace torch::autograd {
33
+
34
+ struct Node;
35
+
36
+ } // namespace torch::autograd
37
+
38
+ namespace at::impl {
39
+
40
+ struct TORCH_API VariableHooksInterface {
41
+ virtual ~VariableHooksInterface() = default;
42
+ virtual TensorBase tensor_data(const TensorBase&) const = 0;
43
+ virtual TensorBase variable_data(const TensorBase&) const = 0;
44
+ virtual const std::shared_ptr<torch::autograd::Node>& grad_fn(
45
+ const TensorBase&) const = 0;
46
+ virtual unsigned _register_hook(
47
+ const TensorBase&,
48
+ std::function<TensorBase(const TensorBase&)> hook) const = 0;
49
+ virtual void remove_hook(const TensorBase&, unsigned pos) const = 0;
50
+ virtual bool is_view(const TensorBase&) const = 0;
51
+ virtual const TensorBase& base(const TensorBase&) const = 0;
52
+ virtual const std::string& name(const TensorBase&) const = 0;
53
+ virtual bool is_leaf(const TensorBase&) const = 0;
54
+ virtual int64_t output_nr(const TensorBase&) const = 0;
55
+ virtual void set_data(const TensorBase&, const TensorBase&) const = 0;
56
+ virtual TensorBase data(const TensorBase&) const = 0;
57
+ virtual int64_t _version(const TensorBase&) const = 0;
58
+ virtual void retain_grad(const TensorBase&) const = 0;
59
+ virtual bool retains_grad(const TensorBase&) const = 0;
60
+ virtual void _backward(
61
+ const Tensor&,
62
+ TensorList,
63
+ const std::optional<Tensor>&,
64
+ std::optional<bool>,
65
+ bool) const = 0;
66
+ virtual void requires_grad_(const TensorBase&, bool) const = 0;
67
+ virtual void basic_autograd_not_implemented_fallback(
68
+ const c10::OperatorHandle& op,
69
+ c10::DispatchKeySet dispatch_keys,
70
+ torch::jit::Stack* stack) const = 0;
71
+ };
72
+
73
+ TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);
74
+ TORCH_API VariableHooksInterface* GetVariableHooks();
75
+ TORCH_API bool HasVariableHooks();
76
+
77
+ struct TORCH_API VariableHooksRegisterer {
78
+ explicit VariableHooksRegisterer(VariableHooksInterface* hooks) {
79
+ SetVariableHooks(hooks);
80
+ }
81
+ };
82
+
83
+ } // namespace at::impl
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Variadic.h ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <utility>
4
+
5
+ #include <c10/util/ArrayRef.h>
6
+ #include <ATen/core/List.h>
7
+
8
+ namespace at {
9
+
10
+ // This class allows you to write variadic functions which
11
+ // call a (possibly overloaded) function on each argument,
12
+ // in order. This is most commonly used in autogenerated code,
13
+ // where it is convenient to have a function that can uniformly
14
+ // take arguments of different types. If your arguments
15
+ // are homogeneous, consider using a std::initializer_list instead.
16
+ //
17
+ // For examples of this in use, see torch/csrc/utils/variadic.h
18
+ template <typename F>
19
+ struct IterArgs {
20
+ template <typename... Args>
21
+ inline F& apply() {
22
+ return self();
23
+ }
24
+
25
+ // NB: Use perfect forwarding here, otherwise we'll make value
26
+ // copies of all arguments!
27
+ template <typename T, typename... Args>
28
+ inline F& apply(T&& arg, Args&&... args) {
29
+ self()(std::forward<T>(arg));
30
+ if (self().short_circuit()) {
31
+ return self();
32
+ } else {
33
+ return apply(std::forward<Args>(args)...);
34
+ }
35
+ }
36
+
37
+ // Here are some handy overloads which provide sensible
38
+ // defaults for container-like structures that one might
39
+ // be interested in recursing into. You can enable them
40
+ // by adding:
41
+ //
42
+ // using IterArgs<YourStructName>::operator()
43
+ //
44
+ // to your struct. These are not enabled by default because
45
+ // you may be able to process these structures more efficiently
46
+ // than handling them one-by-one.
47
+
48
+ template <typename T>
49
+ void operator()(c10::IListRef<T> args) {
50
+ for (const auto& arg : args) {
51
+ self()(arg);
52
+ if (self().short_circuit())
53
+ return;
54
+ }
55
+ }
56
+
57
+ template <typename T>
58
+ void operator()(at::ArrayRef<T> args) {
59
+ for (const auto& arg : args) {
60
+ self()(arg);
61
+ if (self().short_circuit())
62
+ return;
63
+ }
64
+ }
65
+
66
+ template <typename T>
67
+ void operator()(const torch::List<T>& args) {
68
+ for (const auto& arg : args) {
69
+ self()(arg);
70
+ if (self().short_circuit())
71
+ return;
72
+ }
73
+ }
74
+
75
+ // NB: we need to specify std::vector manually as C++ won't
76
+ // do an implicit conversion to make a template deduction go through.
77
+ template <typename T>
78
+ void operator()(const std::vector<T>& args) {
79
+ self()(at::ArrayRef<T>{args});
80
+ }
81
+
82
+ constexpr bool short_circuit() const {
83
+ return false;
84
+ }
85
+
86
+ private:
87
+ inline F& self() {
88
+ return *static_cast<F*>(this);
89
+ }
90
+ };
91
+
92
+ } // namespace torch
.venv/lib/python3.12/site-packages/torch/include/ATen/core/Vitals.h ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ostream>
3
+ #include <sstream>
4
+ #include <unordered_map>
5
+
6
+ #include <c10/core/impl/LocalDispatchKeySet.h>
7
+
8
+ namespace at::vitals {
9
+
10
+ TORCH_API bool torchVitalEnabled();
11
+
12
+ struct TORCH_API TorchVitalAttr {
13
+ // always initialized to empty
14
+ std::string value;
15
+ template <typename T>
16
+ TorchVitalAttr& operator<<(const T& t) {
17
+ if (torchVitalEnabled()) {
18
+ std::stringstream ss;
19
+ ss << t;
20
+ value += ss.str();
21
+ }
22
+ return *this;
23
+ }
24
+
25
+ template <typename T>
26
+ void write(const T& t, bool force) {
27
+ if (force || torchVitalEnabled()) {
28
+ std::stringstream ss;
29
+ ss << t;
30
+ value = ss.str();
31
+ }
32
+ }
33
+ };
34
+
35
+ struct TORCH_API TorchVital {
36
+ std::string name;
37
+ std::unordered_map<std::string, TorchVitalAttr> attrs;
38
+
39
+ explicit TorchVital(std::string n) : name(std::move(n)) {}
40
+ TorchVital(const TorchVital&) = default;
41
+ TorchVital(TorchVital&&) = default;
42
+ TorchVital& operator=(const TorchVital&) = default;
43
+ TorchVital& operator=(TorchVital&&) = default;
44
+ TorchVital() = delete;
45
+
46
+ TorchVitalAttr& create(const std::string& attr);
47
+ TorchVitalAttr& create(const std::string& attr, bool force);
48
+ friend std::ostream& operator<<(std::ostream& os, const TorchVital& dt);
49
+
50
+ ~TorchVital();
51
+ };
52
+
53
+ std::ostream& operator<<(std::ostream& os, TorchVital const& tv);
54
+
55
+ // A way to access vitals by string names instead of by global reference.
56
+ // This enables access to vitals from the PythonAPI.
57
+ class TORCH_API APIVitals {
58
+ public:
59
+ bool vitals_enabled;
60
+
61
+ // Set any vital sign that was added to the map.
62
+ bool setVital(
63
+ const std::string& vital_name,
64
+ const std::string& attr_name,
65
+ const std::string& value,
66
+ bool force = false);
67
+ std::string readVitals();
68
+
69
+ APIVitals();
70
+
71
+ // Ensure this stays a singleton
72
+ APIVitals(APIVitals const& other) = delete;
73
+ APIVitals(APIVitals&& other) = delete;
74
+ APIVitals& operator=(const APIVitals&) = delete;
75
+ APIVitals& operator=(APIVitals&&) = delete;
76
+ ~APIVitals() = default;
77
+
78
+ private:
79
+ std::unordered_map<std::string, TorchVital> name_map_;
80
+ };
81
+
82
+ extern TORCH_API APIVitals VitalsAPI;
83
+
84
+ } // namespace at::vitals
85
+
86
+ #define TORCH_VITAL_DECLARE(name) \
87
+ TORCH_API at::vitals::TorchVital TorchVital_##name;
88
+
89
+ #define TORCH_VITAL_DEFINE(name) \
90
+ TORCH_API at::vitals::TorchVital TorchVital_##name(#name);
91
+
92
+ #define TORCH_VITAL_BASE(name) TorchVital_##name
93
+
94
+ #define TORCH_VITAL(name, attr) TORCH_VITAL_BASE(name).create(#attr)
.venv/lib/python3.12/site-packages/torch/include/ATen/core/alias_info.h ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <set>
3
+ #include <string>
4
+ #include <unordered_set>
5
+ #include <vector>
6
+ #include <ATen/core/symbol.h>
7
+ #include <c10/util/Exception.h>
8
+ #include <c10/util/hash.h>
9
+
10
+ namespace c10 {
11
+ /**
12
+ * class AliasInfo
13
+ *
14
+ * Data structure to hold aliasing information for an `Argument`. They can be
15
+ * nested to represent aliasing information on contained types.
16
+ *
17
+ * There is a `beforeSet` which describes the aliasing information before the
18
+ * operator executes, and an `afterSet` that describes aliasing info
19
+ * after execution.
20
+ */
21
+ class AliasInfo {
22
+ public:
23
+ AliasInfo() = default;
24
+ AliasInfo(bool is_write, const std::set<std::string>& before_qual_strings, const std::set<std::string>& after_qual_strings) : isWrite_(is_write) {
25
+ for (const auto& s: before_qual_strings) {
26
+ beforeSets_.insert(Symbol::fromQualString(s));
27
+ }
28
+ for (const auto& s : after_qual_strings) {
29
+ afterSets_.insert(Symbol::fromQualString(s));
30
+ }
31
+ }
32
+ // Symbol for the set that can alias anything
33
+ static Symbol wildcardSet() {
34
+ static const Symbol wc = Symbol::fromQualString("alias::*");
35
+ return wc;
36
+ }
37
+
38
+ void setIsWrite(bool isWrite) {
39
+ isWrite_ = isWrite;
40
+ }
41
+
42
+ bool isWrite() const {
43
+ return isWrite_;
44
+ }
45
+
46
+ void addBeforeSet(Symbol aliasSet) {
47
+ beforeSets_.insert(aliasSet);
48
+ }
49
+
50
+ void addAfterSet(Symbol aliasSet) {
51
+ afterSets_.insert(aliasSet);
52
+ }
53
+
54
+ const std::unordered_set<Symbol>& beforeSets() const {
55
+ return beforeSets_;
56
+ }
57
+
58
+ const std::unordered_set<Symbol>& afterSets() const {
59
+ return afterSets_;
60
+ }
61
+
62
+ Symbol beforeSet() const {
63
+ AT_ASSERT(beforeSets_.size() == 1);
64
+ return *beforeSets_.begin();
65
+ }
66
+
67
+ bool isWildcardBefore() const {
68
+ return beforeSets_.count(wildcardSet()) != 0;
69
+ }
70
+
71
+ bool isWildcardAfter() const {
72
+ return afterSets_.count(wildcardSet()) != 0;
73
+ }
74
+
75
+ // the alias info for the contained types of the type
76
+ // e.g. if this is an annotation on List[T], `sets` refers to
77
+ // the alias sets that the list may be in
78
+ // while containedTypes()[0] refers to the sets that members of the list
79
+ // may be in
80
+ void addContainedType(AliasInfo aliasInfo) {
81
+ containedTypes_.push_back(std::move(aliasInfo));
82
+ }
83
+ const std::vector<AliasInfo>& containedTypes() const {
84
+ return containedTypes_;
85
+ }
86
+
87
+ private:
88
+ std::unordered_set<Symbol> beforeSets_;
89
+ std::unordered_set<Symbol> afterSets_;
90
+ std::vector<AliasInfo> containedTypes_;
91
+ bool isWrite_ = false;
92
+ };
93
+
94
+ inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
95
+ return lhs.isWrite() == rhs.isWrite()
96
+ && lhs.beforeSets() == rhs.beforeSets()
97
+ && lhs.afterSets() == rhs.afterSets()
98
+ && lhs.containedTypes() == rhs.containedTypes();
99
+ }
100
+
101
+ // this does match the way things are represented in the schema
102
+ inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
103
+ out << "(";
104
+ bool first = true;
105
+ for (const auto& set : aliasInfo.beforeSets()) {
106
+ if (first) {
107
+ first = false;
108
+ } else {
109
+ out << "|";
110
+ }
111
+ out << set.toUnqualString();
112
+ }
113
+ if (aliasInfo.isWrite()) {
114
+ out << "!";
115
+ }
116
+ if (aliasInfo.beforeSets() != aliasInfo.afterSets()) {
117
+ out << " -> ";
118
+ first = true;
119
+ for (const auto& set : aliasInfo.afterSets()) {
120
+ if (first) {
121
+ first = false;
122
+ } else {
123
+ out << "|";
124
+ }
125
+ out << set.toUnqualString();
126
+ }
127
+ }
128
+ out << ")";
129
+ return out;
130
+ }
131
+ } // namespace c10
132
+
133
+ namespace std {
134
+ template <>
135
+ struct hash<c10::AliasInfo> {
136
+ size_t operator()(const c10::AliasInfo& aliasInfo) const {
137
+ auto hash = std::hash<bool>()(aliasInfo.isWrite());
138
+
139
+ // NOTE: for unordered_set hashes, we couldn't use hash_combine
140
+ // because hash_combine is order dependent. Instead, we choose to
141
+ // use XOR as the combining function as XOR is commutative.
142
+ size_t before_set_hash_seed = 0;
143
+ for (auto &e: aliasInfo.beforeSets()) {
144
+ auto symbol_hash = std::hash<c10::Symbol>()(e);
145
+ before_set_hash_seed = before_set_hash_seed ^ symbol_hash;
146
+ }
147
+ size_t after_set_hash_seed = 0;
148
+ for (auto &e: aliasInfo.afterSets()) {
149
+ auto symbol_hash = std::hash<c10::Symbol>()(e);
150
+ after_set_hash_seed = after_set_hash_seed ^ symbol_hash;
151
+ }
152
+
153
+ hash = c10::hash_combine(hash, before_set_hash_seed);
154
+ hash = c10::hash_combine(hash, after_set_hash_seed);
155
+ for (auto &e: aliasInfo.containedTypes()) {
156
+ auto contained_type_hash = std::hash<c10::AliasInfo>()(e);
157
+ hash = c10::hash_combine(hash, contained_type_hash);
158
+ }
159
+ return hash;
160
+ }
161
+ };
162
+ }
.venv/lib/python3.12/site-packages/torch/include/ATen/core/aten_interned_strings.h ADDED
@@ -0,0 +1,2294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from aten_interned_strings.h
4
+
5
+ #if defined(TORCH_ASSERT_NO_OPERATORS) || defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
6
+ #error This change adds a dependency on native_functions.yaml, \
7
+ meaning the file will need to be re-compiled every time an operator \
8
+ is changed or added. Consider if including <ATen/core/symbol.h> for \
9
+ the c10::Symbol class would be sufficient, or if your change would be \
10
+ better placed in another file.
11
+ #endif
12
+
13
+ // ATen symbols correspond exactly to operators defined in ATen. Every
14
+ // symbol here corresponds exactly to an ATen operation defined in
15
+ // native_functions.yaml; attributes are in one-to-one correspondence
16
+ // with their ATen name.
17
+
18
+ #define FORALL_ATEN_BASE_SYMBOLS(_) \
19
+ _(aten, __and__) \
20
+ _(aten, __iand__) \
21
+ _(aten, __ilshift__) \
22
+ _(aten, __ior__) \
23
+ _(aten, __irshift__) \
24
+ _(aten, __ixor__) \
25
+ _(aten, __lshift__) \
26
+ _(aten, __or__) \
27
+ _(aten, __rshift__) \
28
+ _(aten, __xor__) \
29
+ _(aten, _adaptive_avg_pool2d) \
30
+ _(aten, _adaptive_avg_pool2d_backward) \
31
+ _(aten, _adaptive_avg_pool3d) \
32
+ _(aten, _adaptive_avg_pool3d_backward) \
33
+ _(aten, _add_batch_dim) \
34
+ _(aten, _add_relu) \
35
+ _(aten, _add_relu_) \
36
+ _(aten, _addmm_activation) \
37
+ _(aten, _aminmax) \
38
+ _(aten, _amp_foreach_non_finite_check_and_unscale) \
39
+ _(aten, _amp_foreach_non_finite_check_and_unscale_) \
40
+ _(aten, _amp_update_scale) \
41
+ _(aten, _amp_update_scale_) \
42
+ _(aten, _assert_async) \
43
+ _(aten, _assert_scalar) \
44
+ _(aten, _assert_tensor_metadata) \
45
+ _(aten, _autocast_to_full_precision) \
46
+ _(aten, _autocast_to_reduced_precision) \
47
+ _(aten, _backward) \
48
+ _(aten, _batch_norm_impl_index) \
49
+ _(aten, _batch_norm_impl_index_backward) \
50
+ _(aten, _batch_norm_no_update) \
51
+ _(aten, _batch_norm_with_update) \
52
+ _(aten, _batch_norm_with_update_functional) \
53
+ _(aten, _cast_Byte) \
54
+ _(aten, _cast_Char) \
55
+ _(aten, _cast_Double) \
56
+ _(aten, _cast_Float) \
57
+ _(aten, _cast_Half) \
58
+ _(aten, _cast_Int) \
59
+ _(aten, _cast_Long) \
60
+ _(aten, _cast_Short) \
61
+ _(aten, _cdist_backward) \
62
+ _(aten, _cdist_forward) \
63
+ _(aten, _cholesky_solve_helper) \
64
+ _(aten, _choose_qparams_per_tensor) \
65
+ _(aten, _chunk_cat) \
66
+ _(aten, _coalesce) \
67
+ _(aten, _coalesced) \
68
+ _(aten, _coalesced_) \
69
+ _(aten, _compute_linear_combination) \
70
+ _(aten, _conj) \
71
+ _(aten, _conj_copy) \
72
+ _(aten, _conj_physical) \
73
+ _(aten, _conv_depthwise2d) \
74
+ _(aten, _convert_indices_from_coo_to_csr) \
75
+ _(aten, _convert_indices_from_csr_to_coo) \
76
+ _(aten, _convert_weight_to_int4pack) \
77
+ _(aten, _convert_weight_to_int4pack_for_cpu) \
78
+ _(aten, _convolution) \
79
+ _(aten, _convolution_double_backward) \
80
+ _(aten, _convolution_mode) \
81
+ _(aten, _copy_from) \
82
+ _(aten, _copy_from_and_resize) \
83
+ _(aten, _cslt_compress) \
84
+ _(aten, _cslt_sparse_mm) \
85
+ _(aten, _cslt_sparse_mm_search) \
86
+ _(aten, _ctc_loss) \
87
+ _(aten, _ctc_loss_backward) \
88
+ _(aten, _cudnn_attention_forward) \
89
+ _(aten, _cudnn_ctc_loss) \
90
+ _(aten, _cudnn_init_dropout_state) \
91
+ _(aten, _cudnn_rnn) \
92
+ _(aten, _cudnn_rnn_backward) \
93
+ _(aten, _cudnn_rnn_flatten_weight) \
94
+ _(aten, _cufft_clear_plan_cache) \
95
+ _(aten, _cufft_get_plan_cache_max_size) \
96
+ _(aten, _cufft_get_plan_cache_size) \
97
+ _(aten, _cufft_set_plan_cache_max_size) \
98
+ _(aten, _cummax_helper) \
99
+ _(aten, _cummin_helper) \
100
+ _(aten, _debug_has_internal_overlap) \
101
+ _(aten, _dimI) \
102
+ _(aten, _dimV) \
103
+ _(aten, _dim_arange) \
104
+ _(aten, _dirichlet_grad) \
105
+ _(aten, _dyn_quant_matmul_4bit) \
106
+ _(aten, _dyn_quant_pack_4bit_weight) \
107
+ _(aten, _efficient_attention_backward) \
108
+ _(aten, _efficient_attention_forward) \
109
+ _(aten, _efficientzerotensor) \
110
+ _(aten, _embedding_bag) \
111
+ _(aten, _embedding_bag_backward) \
112
+ _(aten, _embedding_bag_dense_backward) \
113
+ _(aten, _embedding_bag_forward_only) \
114
+ _(aten, _embedding_bag_per_sample_weights_backward) \
115
+ _(aten, _embedding_bag_sparse_backward) \
116
+ _(aten, _empty_affine_quantized) \
117
+ _(aten, _empty_per_channel_affine_quantized) \
118
+ _(aten, _euclidean_dist) \
119
+ _(aten, _fake_quantize_learnable_per_channel_affine) \
120
+ _(aten, _fake_quantize_learnable_per_channel_affine_backward) \
121
+ _(aten, _fake_quantize_learnable_per_tensor_affine) \
122
+ _(aten, _fake_quantize_learnable_per_tensor_affine_backward) \
123
+ _(aten, _fake_quantize_per_tensor_affine_cachemask_tensor_qparams) \
124
+ _(aten, _fft_c2c) \
125
+ _(aten, _fft_c2r) \
126
+ _(aten, _fft_r2c) \
127
+ _(aten, _fill_mem_eff_dropout_mask) \
128
+ _(aten, _fill_mem_eff_dropout_mask_) \
129
+ _(aten, _flash_attention_backward) \
130
+ _(aten, _flash_attention_forward) \
131
+ _(aten, _foobar) \
132
+ _(aten, _foreach_abs) \
133
+ _(aten, _foreach_abs_) \
134
+ _(aten, _foreach_acos) \
135
+ _(aten, _foreach_acos_) \
136
+ _(aten, _foreach_add) \
137
+ _(aten, _foreach_add_) \
138
+ _(aten, _foreach_addcdiv) \
139
+ _(aten, _foreach_addcdiv_) \
140
+ _(aten, _foreach_addcmul) \
141
+ _(aten, _foreach_addcmul_) \
142
+ _(aten, _foreach_asin) \
143
+ _(aten, _foreach_asin_) \
144
+ _(aten, _foreach_atan) \
145
+ _(aten, _foreach_atan_) \
146
+ _(aten, _foreach_ceil) \
147
+ _(aten, _foreach_ceil_) \
148
+ _(aten, _foreach_clamp_max) \
149
+ _(aten, _foreach_clamp_max_) \
150
+ _(aten, _foreach_clamp_min) \
151
+ _(aten, _foreach_clamp_min_) \
152
+ _(aten, _foreach_copy) \
153
+ _(aten, _foreach_copy_) \
154
+ _(aten, _foreach_cos) \
155
+ _(aten, _foreach_cos_) \
156
+ _(aten, _foreach_cosh) \
157
+ _(aten, _foreach_cosh_) \
158
+ _(aten, _foreach_div) \
159
+ _(aten, _foreach_div_) \
160
+ _(aten, _foreach_erf) \
161
+ _(aten, _foreach_erf_) \
162
+ _(aten, _foreach_erfc) \
163
+ _(aten, _foreach_erfc_) \
164
+ _(aten, _foreach_exp) \
165
+ _(aten, _foreach_exp_) \
166
+ _(aten, _foreach_expm1) \
167
+ _(aten, _foreach_expm1_) \
168
+ _(aten, _foreach_floor) \
169
+ _(aten, _foreach_floor_) \
170
+ _(aten, _foreach_frac) \
171
+ _(aten, _foreach_frac_) \
172
+ _(aten, _foreach_lerp) \
173
+ _(aten, _foreach_lerp_) \
174
+ _(aten, _foreach_lgamma) \
175
+ _(aten, _foreach_lgamma_) \
176
+ _(aten, _foreach_log) \
177
+ _(aten, _foreach_log10) \
178
+ _(aten, _foreach_log10_) \
179
+ _(aten, _foreach_log1p) \
180
+ _(aten, _foreach_log1p_) \
181
+ _(aten, _foreach_log2) \
182
+ _(aten, _foreach_log2_) \
183
+ _(aten, _foreach_log_) \
184
+ _(aten, _foreach_max) \
185
+ _(aten, _foreach_maximum) \
186
+ _(aten, _foreach_maximum_) \
187
+ _(aten, _foreach_minimum) \
188
+ _(aten, _foreach_minimum_) \
189
+ _(aten, _foreach_mul) \
190
+ _(aten, _foreach_mul_) \
191
+ _(aten, _foreach_neg) \
192
+ _(aten, _foreach_neg_) \
193
+ _(aten, _foreach_norm) \
194
+ _(aten, _foreach_pow) \
195
+ _(aten, _foreach_pow_) \
196
+ _(aten, _foreach_reciprocal) \
197
+ _(aten, _foreach_reciprocal_) \
198
+ _(aten, _foreach_round) \
199
+ _(aten, _foreach_round_) \
200
+ _(aten, _foreach_rsqrt) \
201
+ _(aten, _foreach_rsqrt_) \
202
+ _(aten, _foreach_sigmoid) \
203
+ _(aten, _foreach_sigmoid_) \
204
+ _(aten, _foreach_sign) \
205
+ _(aten, _foreach_sign_) \
206
+ _(aten, _foreach_sin) \
207
+ _(aten, _foreach_sin_) \
208
+ _(aten, _foreach_sinh) \
209
+ _(aten, _foreach_sinh_) \
210
+ _(aten, _foreach_sqrt) \
211
+ _(aten, _foreach_sqrt_) \
212
+ _(aten, _foreach_sub) \
213
+ _(aten, _foreach_sub_) \
214
+ _(aten, _foreach_tan) \
215
+ _(aten, _foreach_tan_) \
216
+ _(aten, _foreach_tanh) \
217
+ _(aten, _foreach_tanh_) \
218
+ _(aten, _foreach_trunc) \
219
+ _(aten, _foreach_trunc_) \
220
+ _(aten, _foreach_zero) \
221
+ _(aten, _foreach_zero_) \
222
+ _(aten, _functional_assert_async) \
223
+ _(aten, _functional_assert_scalar) \
224
+ _(aten, _functional_sym_constrain_range) \
225
+ _(aten, _functional_sym_constrain_range_for_size) \
226
+ _(aten, _fused_adagrad) \
227
+ _(aten, _fused_adagrad_) \
228
+ _(aten, _fused_adam) \
229
+ _(aten, _fused_adam_) \
230
+ _(aten, _fused_adamw) \
231
+ _(aten, _fused_adamw_) \
232
+ _(aten, _fused_dropout) \
233
+ _(aten, _fused_moving_avg_obs_fq_helper) \
234
+ _(aten, _fused_moving_avg_obs_fq_helper_functional) \
235
+ _(aten, _fused_rms_norm) \
236
+ _(aten, _fused_sdp_choice) \
237
+ _(aten, _fused_sgd) \
238
+ _(aten, _fused_sgd_) \
239
+ _(aten, _fw_primal) \
240
+ _(aten, _fw_primal_copy) \
241
+ _(aten, _gather_sparse_backward) \
242
+ _(aten, _grid_sampler_2d_cpu_fallback) \
243
+ _(aten, _grid_sampler_2d_cpu_fallback_backward) \
244
+ _(aten, _grouped_mm) \
245
+ _(aten, _has_compatible_shallow_copy_type) \
246
+ _(aten, _has_same_storage_numel) \
247
+ _(aten, _histogramdd_bin_edges) \
248
+ _(aten, _histogramdd_from_bin_cts) \
249
+ _(aten, _histogramdd_from_bin_tensors) \
250
+ _(aten, _index_put_impl) \
251
+ _(aten, _index_put_impl_) \
252
+ _(aten, _indices) \
253
+ _(aten, _indices_copy) \
254
+ _(aten, _int_mm) \
255
+ _(aten, _is_all_true) \
256
+ _(aten, _is_any_true) \
257
+ _(aten, _is_zerotensor) \
258
+ _(aten, _jagged_to_padded_dense_forward) \
259
+ _(aten, _lazy_clone) \
260
+ _(aten, _linalg_check_errors) \
261
+ _(aten, _linalg_det) \
262
+ _(aten, _linalg_eigh) \
263
+ _(aten, _linalg_eigvals) \
264
+ _(aten, _linalg_slogdet) \
265
+ _(aten, _linalg_solve_ex) \
266
+ _(aten, _linalg_svd) \
267
+ _(aten, _local_scalar_dense) \
268
+ _(aten, _log_softmax) \
269
+ _(aten, _log_softmax_backward_data) \
270
+ _(aten, _logcumsumexp) \
271
+ _(aten, _lstm_mps) \
272
+ _(aten, _lu_with_info) \
273
+ _(aten, _make_dep_token) \
274
+ _(aten, _make_dual) \
275
+ _(aten, _make_dual_copy) \
276
+ _(aten, _make_per_channel_quantized_tensor) \
277
+ _(aten, _make_per_tensor_quantized_tensor) \
278
+ _(aten, _masked_scale) \
279
+ _(aten, _masked_softmax) \
280
+ _(aten, _masked_softmax_backward) \
281
+ _(aten, _mixed_dtypes_linear) \
282
+ _(aten, _mkldnn_reshape) \
283
+ _(aten, _mkldnn_transpose) \
284
+ _(aten, _mkldnn_transpose_) \
285
+ _(aten, _mps_convolution) \
286
+ _(aten, _mps_convolution_transpose) \
287
+ _(aten, _native_batch_norm_legit) \
288
+ _(aten, _native_batch_norm_legit_functional) \
289
+ _(aten, _native_batch_norm_legit_no_training) \
290
+ _(aten, _native_multi_head_attention) \
291
+ _(aten, _neg_view) \
292
+ _(aten, _neg_view_copy) \
293
+ _(aten, _nested_compute_contiguous_strides_offsets) \
294
+ _(aten, _nested_from_padded) \
295
+ _(aten, _nested_from_padded_and_nested_example) \
296
+ _(aten, _nested_from_padded_tensor) \
297
+ _(aten, _nested_get_jagged_dummy) \
298
+ _(aten, _nested_get_lengths) \
299
+ _(aten, _nested_get_max_seqlen) \
300
+ _(aten, _nested_get_min_seqlen) \
301
+ _(aten, _nested_get_offsets) \
302
+ _(aten, _nested_get_ragged_idx) \
303
+ _(aten, _nested_get_values) \
304
+ _(aten, _nested_get_values_copy) \
305
+ _(aten, _nested_select_backward) \
306
+ _(aten, _nested_sum_backward) \
307
+ _(aten, _nested_tensor_from_mask) \
308
+ _(aten, _nested_tensor_from_mask_left_aligned) \
309
+ _(aten, _nested_tensor_from_tensor_list) \
310
+ _(aten, _nested_tensor_size) \
311
+ _(aten, _nested_tensor_softmax_with_shape) \
312
+ _(aten, _nested_tensor_storage_offsets) \
313
+ _(aten, _nested_tensor_strides) \
314
+ _(aten, _nested_view_from_buffer) \
315
+ _(aten, _nested_view_from_buffer_copy) \
316
+ _(aten, _nested_view_from_jagged) \
317
+ _(aten, _nested_view_from_jagged_copy) \
318
+ _(aten, _new_zeros_with_same_feature_meta) \
319
+ _(aten, _nnpack_available) \
320
+ _(aten, _nnpack_spatial_convolution) \
321
+ _(aten, _nnz) \
322
+ _(aten, _pack_padded_sequence) \
323
+ _(aten, _pack_padded_sequence_backward) \
324
+ _(aten, _pad_circular) \
325
+ _(aten, _pad_enum) \
326
+ _(aten, _pad_packed_sequence) \
327
+ _(aten, _padded_dense_to_jagged_forward) \
328
+ _(aten, _pdist_backward) \
329
+ _(aten, _pdist_forward) \
330
+ _(aten, _pin_memory) \
331
+ _(aten, _prelu_kernel) \
332
+ _(aten, _prelu_kernel_backward) \
333
+ _(aten, _print) \
334
+ _(aten, _propagate_xla_data) \
335
+ _(aten, _remove_batch_dim) \
336
+ _(aten, _reshape_alias) \
337
+ _(aten, _reshape_alias_copy) \
338
+ _(aten, _reshape_copy) \
339
+ _(aten, _reshape_from_tensor) \
340
+ _(aten, _resize_output) \
341
+ _(aten, _resize_output_) \
342
+ _(aten, _rowwise_prune) \
343
+ _(aten, _safe_softmax) \
344
+ _(aten, _sample_dirichlet) \
345
+ _(aten, _saturate_weight_to_fp16) \
346
+ _(aten, _scaled_dot_product_attention_math) \
347
+ _(aten, _scaled_dot_product_attention_math_for_mps) \
348
+ _(aten, _scaled_dot_product_cudnn_attention) \
349
+ _(aten, _scaled_dot_product_cudnn_attention_backward) \
350
+ _(aten, _scaled_dot_product_efficient_attention) \
351
+ _(aten, _scaled_dot_product_efficient_attention_backward) \
352
+ _(aten, _scaled_dot_product_flash_attention) \
353
+ _(aten, _scaled_dot_product_flash_attention_backward) \
354
+ _(aten, _scaled_dot_product_flash_attention_for_cpu) \
355
+ _(aten, _scaled_dot_product_flash_attention_for_cpu_backward) \
356
+ _(aten, _scaled_dot_product_fused_attention_overrideable) \
357
+ _(aten, _scaled_dot_product_fused_attention_overrideable_backward) \
358
+ _(aten, _scaled_grouped_mm) \
359
+ _(aten, _scaled_mm) \
360
+ _(aten, _segment_reduce_backward) \
361
+ _(aten, _shape_as_tensor) \
362
+ _(aten, _slow_conv2d_backward) \
363
+ _(aten, _slow_conv2d_forward) \
364
+ _(aten, _sobol_engine_draw) \
365
+ _(aten, _sobol_engine_ff) \
366
+ _(aten, _sobol_engine_ff_) \
367
+ _(aten, _sobol_engine_initialize_state) \
368
+ _(aten, _sobol_engine_initialize_state_) \
369
+ _(aten, _sobol_engine_scramble) \
370
+ _(aten, _sobol_engine_scramble_) \
371
+ _(aten, _softmax) \
372
+ _(aten, _softmax_backward_data) \
373
+ _(aten, _sparse_addmm) \
374
+ _(aten, _sparse_broadcast_to) \
375
+ _(aten, _sparse_broadcast_to_copy) \
376
+ _(aten, _sparse_bsc_tensor_unsafe) \
377
+ _(aten, _sparse_bsr_tensor_unsafe) \
378
+ _(aten, _sparse_compressed_tensor_unsafe) \
379
+ _(aten, _sparse_compressed_tensor_with_dims) \
380
+ _(aten, _sparse_coo_tensor_unsafe) \
381
+ _(aten, _sparse_coo_tensor_with_dims) \
382
+ _(aten, _sparse_coo_tensor_with_dims_and_tensors) \
383
+ _(aten, _sparse_csc_tensor_unsafe) \
384
+ _(aten, _sparse_csr_prod) \
385
+ _(aten, _sparse_csr_sum) \
386
+ _(aten, _sparse_csr_tensor_unsafe) \
387
+ _(aten, _sparse_log_softmax) \
388
+ _(aten, _sparse_log_softmax_backward_data) \
389
+ _(aten, _sparse_mask_projection) \
390
+ _(aten, _sparse_mm) \
391
+ _(aten, _sparse_mm_reduce_impl) \
392
+ _(aten, _sparse_mm_reduce_impl_backward) \
393
+ _(aten, _sparse_semi_structured_addmm) \
394
+ _(aten, _sparse_semi_structured_apply) \
395
+ _(aten, _sparse_semi_structured_apply_dense) \
396
+ _(aten, _sparse_semi_structured_linear) \
397
+ _(aten, _sparse_semi_structured_mm) \
398
+ _(aten, _sparse_semi_structured_tile) \
399
+ _(aten, _sparse_softmax) \
400
+ _(aten, _sparse_softmax_backward_data) \
401
+ _(aten, _sparse_sparse_matmul) \
402
+ _(aten, _sparse_sum) \
403
+ _(aten, _sparse_sum_backward) \
404
+ _(aten, _spdiags) \
405
+ _(aten, _spsolve) \
406
+ _(aten, _stack) \
407
+ _(aten, _standard_gamma) \
408
+ _(aten, _standard_gamma_grad) \
409
+ _(aten, _test_ambiguous_defaults) \
410
+ _(aten, _test_autograd_multiple_dispatch) \
411
+ _(aten, _test_autograd_multiple_dispatch_view) \
412
+ _(aten, _test_autograd_multiple_dispatch_view_copy) \
413
+ _(aten, _test_check_tensor) \
414
+ _(aten, _test_functorch_fallback) \
415
+ _(aten, _test_optional_filled_intlist) \
416
+ _(aten, _test_optional_floatlist) \
417
+ _(aten, _test_optional_intlist) \
418
+ _(aten, _test_parallel_materialize) \
419
+ _(aten, _test_serialization_subcmul) \
420
+ _(aten, _test_string_default) \
421
+ _(aten, _test_warn_in_autograd) \
422
+ _(aten, _thnn_differentiable_gru_cell_backward) \
423
+ _(aten, _thnn_differentiable_lstm_cell_backward) \
424
+ _(aten, _thnn_fused_gru_cell) \
425
+ _(aten, _thnn_fused_gru_cell_backward) \
426
+ _(aten, _thnn_fused_lstm_cell) \
427
+ _(aten, _thnn_fused_lstm_cell_backward) \
428
+ _(aten, _thnn_fused_lstm_cell_backward_impl) \
429
+ _(aten, _to_copy) \
430
+ _(aten, _to_cpu) \
431
+ _(aten, _to_dense) \
432
+ _(aten, _to_sparse) \
433
+ _(aten, _to_sparse_bsc) \
434
+ _(aten, _to_sparse_bsr) \
435
+ _(aten, _to_sparse_csc) \
436
+ _(aten, _to_sparse_csr) \
437
+ _(aten, _to_sparse_semi_structured) \
438
+ _(aten, _transform_bias_rescale_qkv) \
439
+ _(aten, _transformer_encoder_layer_fwd) \
440
+ _(aten, _trilinear) \
441
+ _(aten, _triton_multi_head_attention) \
442
+ _(aten, _triton_scaled_dot_attention) \
443
+ _(aten, _unique) \
444
+ _(aten, _unique2) \
445
+ _(aten, _unpack_dual) \
446
+ _(aten, _unsafe_index) \
447
+ _(aten, _unsafe_index_put) \
448
+ _(aten, _unsafe_masked_index) \
449
+ _(aten, _unsafe_masked_index_put_accumulate) \
450
+ _(aten, _unsafe_view) \
451
+ _(aten, _upsample_bicubic2d_aa) \
452
+ _(aten, _upsample_bicubic2d_aa_backward) \
453
+ _(aten, _upsample_bilinear2d_aa) \
454
+ _(aten, _upsample_bilinear2d_aa_backward) \
455
+ _(aten, _upsample_nearest_exact1d) \
456
+ _(aten, _upsample_nearest_exact1d_backward) \
457
+ _(aten, _upsample_nearest_exact2d) \
458
+ _(aten, _upsample_nearest_exact2d_backward) \
459
+ _(aten, _upsample_nearest_exact3d) \
460
+ _(aten, _upsample_nearest_exact3d_backward) \
461
+ _(aten, _use_cudnn_ctc_loss) \
462
+ _(aten, _use_cudnn_rnn_flatten_weight) \
463
+ _(aten, _validate_compressed_sparse_indices) \
464
+ _(aten, _validate_sparse_bsc_tensor_args) \
465
+ _(aten, _validate_sparse_bsr_tensor_args) \
466
+ _(aten, _validate_sparse_compressed_tensor_args) \
467
+ _(aten, _validate_sparse_coo_tensor_args) \
468
+ _(aten, _validate_sparse_csc_tensor_args) \
469
+ _(aten, _validate_sparse_csr_tensor_args) \
470
+ _(aten, _values) \
471
+ _(aten, _values_copy) \
472
+ _(aten, _version) \
473
+ _(aten, _weight_int4pack_mm) \
474
+ _(aten, _weight_int4pack_mm_for_cpu) \
475
+ _(aten, _weight_int4pack_mm_with_scales_and_zeros) \
476
+ _(aten, _weight_int8pack_mm) \
477
+ _(aten, _weight_norm) \
478
+ _(aten, _weight_norm_differentiable_backward) \
479
+ _(aten, _weight_norm_interface) \
480
+ _(aten, _weight_norm_interface_backward) \
481
+ _(aten, _wrapped_linear_prepack) \
482
+ _(aten, _wrapped_quantized_linear_prepacked) \
483
+ _(aten, abs) \
484
+ _(aten, abs_) \
485
+ _(aten, absolute) \
486
+ _(aten, absolute_) \
487
+ _(aten, acos) \
488
+ _(aten, acos_) \
489
+ _(aten, acosh) \
490
+ _(aten, acosh_) \
491
+ _(aten, adaptive_avg_pool1d) \
492
+ _(aten, adaptive_avg_pool2d) \
493
+ _(aten, adaptive_avg_pool3d) \
494
+ _(aten, adaptive_avg_pool3d_backward) \
495
+ _(aten, adaptive_max_pool1d) \
496
+ _(aten, adaptive_max_pool2d) \
497
+ _(aten, adaptive_max_pool2d_backward) \
498
+ _(aten, adaptive_max_pool3d) \
499
+ _(aten, adaptive_max_pool3d_backward) \
500
+ _(aten, add) \
501
+ _(aten, add_) \
502
+ _(aten, addbmm) \
503
+ _(aten, addbmm_) \
504
+ _(aten, addcdiv) \
505
+ _(aten, addcdiv_) \
506
+ _(aten, addcmul) \
507
+ _(aten, addcmul_) \
508
+ _(aten, addmm) \
509
+ _(aten, addmm_) \
510
+ _(aten, addmv) \
511
+ _(aten, addmv_) \
512
+ _(aten, addr) \
513
+ _(aten, addr_) \
514
+ _(aten, adjoint) \
515
+ _(aten, affine_grid_generator) \
516
+ _(aten, affine_grid_generator_backward) \
517
+ _(aten, alias) \
518
+ _(aten, alias_copy) \
519
+ _(aten, align_as) \
520
+ _(aten, align_tensors) \
521
+ _(aten, align_to) \
522
+ _(aten, all) \
523
+ _(aten, allclose) \
524
+ _(aten, alpha_dropout) \
525
+ _(aten, alpha_dropout_) \
526
+ _(aten, amax) \
527
+ _(aten, amin) \
528
+ _(aten, aminmax) \
529
+ _(aten, angle) \
530
+ _(aten, any) \
531
+ _(aten, arange) \
532
+ _(aten, arccos) \
533
+ _(aten, arccos_) \
534
+ _(aten, arccosh) \
535
+ _(aten, arccosh_) \
536
+ _(aten, arcsin) \
537
+ _(aten, arcsin_) \
538
+ _(aten, arcsinh) \
539
+ _(aten, arcsinh_) \
540
+ _(aten, arctan) \
541
+ _(aten, arctan2) \
542
+ _(aten, arctan2_) \
543
+ _(aten, arctan_) \
544
+ _(aten, arctanh) \
545
+ _(aten, arctanh_) \
546
+ _(aten, argmax) \
547
+ _(aten, argmin) \
548
+ _(aten, argsort) \
549
+ _(aten, argwhere) \
550
+ _(aten, as_strided) \
551
+ _(aten, as_strided_) \
552
+ _(aten, as_strided_copy) \
553
+ _(aten, as_strided_scatter) \
554
+ _(aten, asin) \
555
+ _(aten, asin_) \
556
+ _(aten, asinh) \
557
+ _(aten, asinh_) \
558
+ _(aten, atan) \
559
+ _(aten, atan2) \
560
+ _(aten, atan2_) \
561
+ _(aten, atan_) \
562
+ _(aten, atanh) \
563
+ _(aten, atanh_) \
564
+ _(aten, atleast_1d) \
565
+ _(aten, atleast_2d) \
566
+ _(aten, atleast_3d) \
567
+ _(aten, avg_pool1d) \
568
+ _(aten, avg_pool2d) \
569
+ _(aten, avg_pool2d_backward) \
570
+ _(aten, avg_pool3d) \
571
+ _(aten, avg_pool3d_backward) \
572
+ _(aten, baddbmm) \
573
+ _(aten, baddbmm_) \
574
+ _(aten, bartlett_window) \
575
+ _(aten, batch_norm) \
576
+ _(aten, batch_norm_backward) \
577
+ _(aten, batch_norm_backward_elemt) \
578
+ _(aten, batch_norm_backward_reduce) \
579
+ _(aten, batch_norm_elemt) \
580
+ _(aten, batch_norm_gather_stats) \
581
+ _(aten, batch_norm_gather_stats_with_counts) \
582
+ _(aten, batch_norm_stats) \
583
+ _(aten, batch_norm_update_stats) \
584
+ _(aten, bernoulli) \
585
+ _(aten, bernoulli_) \
586
+ _(aten, bilinear) \
587
+ _(aten, binary_cross_entropy) \
588
+ _(aten, binary_cross_entropy_backward) \
589
+ _(aten, binary_cross_entropy_with_logits) \
590
+ _(aten, bincount) \
591
+ _(aten, binomial) \
592
+ _(aten, bitwise_and) \
593
+ _(aten, bitwise_and_) \
594
+ _(aten, bitwise_left_shift) \
595
+ _(aten, bitwise_left_shift_) \
596
+ _(aten, bitwise_not) \
597
+ _(aten, bitwise_not_) \
598
+ _(aten, bitwise_or) \
599
+ _(aten, bitwise_or_) \
600
+ _(aten, bitwise_right_shift) \
601
+ _(aten, bitwise_right_shift_) \
602
+ _(aten, bitwise_xor) \
603
+ _(aten, bitwise_xor_) \
604
+ _(aten, blackman_window) \
605
+ _(aten, block_diag) \
606
+ _(aten, bmm) \
607
+ _(aten, broadcast_tensors) \
608
+ _(aten, broadcast_to) \
609
+ _(aten, bucketize) \
610
+ _(aten, can_cast) \
611
+ _(aten, cartesian_prod) \
612
+ _(aten, cat) \
613
+ _(aten, cauchy) \
614
+ _(aten, cauchy_) \
615
+ _(aten, ccol_indices) \
616
+ _(aten, ccol_indices_copy) \
617
+ _(aten, cdist) \
618
+ _(aten, ceil) \
619
+ _(aten, ceil_) \
620
+ _(aten, celu) \
621
+ _(aten, celu_) \
622
+ _(aten, chain_matmul) \
623
+ _(aten, chalf) \
624
+ _(aten, channel_shuffle) \
625
+ _(aten, cholesky) \
626
+ _(aten, cholesky_inverse) \
627
+ _(aten, cholesky_solve) \
628
+ _(aten, choose_qparams_optimized) \
629
+ _(aten, chunk) \
630
+ _(aten, clamp) \
631
+ _(aten, clamp_) \
632
+ _(aten, clamp_max) \
633
+ _(aten, clamp_max_) \
634
+ _(aten, clamp_min) \
635
+ _(aten, clamp_min_) \
636
+ _(aten, clip) \
637
+ _(aten, clip_) \
638
+ _(aten, clone) \
639
+ _(aten, coalesce) \
640
+ _(aten, col2im) \
641
+ _(aten, col_indices) \
642
+ _(aten, col_indices_copy) \
643
+ _(aten, column_stack) \
644
+ _(aten, combinations) \
645
+ _(aten, complex) \
646
+ _(aten, concat) \
647
+ _(aten, concatenate) \
648
+ _(aten, conj) \
649
+ _(aten, conj_physical) \
650
+ _(aten, conj_physical_) \
651
+ _(aten, constant_pad_nd) \
652
+ _(aten, contiguous) \
653
+ _(aten, conv1d) \
654
+ _(aten, conv2d) \
655
+ _(aten, conv3d) \
656
+ _(aten, conv_depthwise3d) \
657
+ _(aten, conv_tbc) \
658
+ _(aten, conv_tbc_backward) \
659
+ _(aten, conv_transpose1d) \
660
+ _(aten, conv_transpose2d) \
661
+ _(aten, conv_transpose3d) \
662
+ _(aten, convolution) \
663
+ _(aten, convolution_backward) \
664
+ _(aten, convolution_backward_overrideable) \
665
+ _(aten, convolution_overrideable) \
666
+ _(aten, copy) \
667
+ _(aten, copy_) \
668
+ _(aten, copy_sparse_to_sparse) \
669
+ _(aten, copy_sparse_to_sparse_) \
670
+ _(aten, copysign) \
671
+ _(aten, copysign_) \
672
+ _(aten, corrcoef) \
673
+ _(aten, cos) \
674
+ _(aten, cos_) \
675
+ _(aten, cosh) \
676
+ _(aten, cosh_) \
677
+ _(aten, cosine_embedding_loss) \
678
+ _(aten, cosine_similarity) \
679
+ _(aten, count_nonzero) \
680
+ _(aten, cov) \
681
+ _(aten, cross) \
682
+ _(aten, cross_entropy_loss) \
683
+ _(aten, crow_indices) \
684
+ _(aten, crow_indices_copy) \
685
+ _(aten, ctc_loss) \
686
+ _(aten, cudnn_affine_grid_generator) \
687
+ _(aten, cudnn_affine_grid_generator_backward) \
688
+ _(aten, cudnn_batch_norm) \
689
+ _(aten, cudnn_batch_norm_backward) \
690
+ _(aten, cudnn_convolution) \
691
+ _(aten, cudnn_convolution_add_relu) \
692
+ _(aten, cudnn_convolution_relu) \
693
+ _(aten, cudnn_convolution_transpose) \
694
+ _(aten, cudnn_grid_sampler) \
695
+ _(aten, cudnn_grid_sampler_backward) \
696
+ _(aten, cudnn_is_acceptable) \
697
+ _(aten, cummax) \
698
+ _(aten, cummaxmin_backward) \
699
+ _(aten, cummin) \
700
+ _(aten, cumprod) \
701
+ _(aten, cumprod_) \
702
+ _(aten, cumprod_backward) \
703
+ _(aten, cumsum) \
704
+ _(aten, cumsum_) \
705
+ _(aten, cumulative_trapezoid) \
706
+ _(aten, data) \
707
+ _(aten, deg2rad) \
708
+ _(aten, deg2rad_) \
709
+ _(aten, dense_dim) \
710
+ _(aten, dequantize) \
711
+ _(aten, det) \
712
+ _(aten, detach) \
713
+ _(aten, detach_) \
714
+ _(aten, detach_copy) \
715
+ _(aten, diag) \
716
+ _(aten, diag_embed) \
717
+ _(aten, diagflat) \
718
+ _(aten, diagonal) \
719
+ _(aten, diagonal_backward) \
720
+ _(aten, diagonal_copy) \
721
+ _(aten, diagonal_scatter) \
722
+ _(aten, diff) \
723
+ _(aten, digamma) \
724
+ _(aten, digamma_) \
725
+ _(aten, dist) \
726
+ _(aten, div) \
727
+ _(aten, div_) \
728
+ _(aten, divide) \
729
+ _(aten, divide_) \
730
+ _(aten, dot) \
731
+ _(aten, dropout) \
732
+ _(aten, dropout_) \
733
+ _(aten, dsplit) \
734
+ _(aten, dstack) \
735
+ _(aten, einsum) \
736
+ _(aten, elu) \
737
+ _(aten, elu_) \
738
+ _(aten, elu_backward) \
739
+ _(aten, embedding) \
740
+ _(aten, embedding_backward) \
741
+ _(aten, embedding_bag) \
742
+ _(aten, embedding_dense_backward) \
743
+ _(aten, embedding_renorm) \
744
+ _(aten, embedding_renorm_) \
745
+ _(aten, embedding_sparse_backward) \
746
+ _(aten, empty) \
747
+ _(aten, empty_like) \
748
+ _(aten, empty_permuted) \
749
+ _(aten, empty_quantized) \
750
+ _(aten, empty_strided) \
751
+ _(aten, eq) \
752
+ _(aten, eq_) \
753
+ _(aten, equal) \
754
+ _(aten, erf) \
755
+ _(aten, erf_) \
756
+ _(aten, erfc) \
757
+ _(aten, erfc_) \
758
+ _(aten, erfinv) \
759
+ _(aten, erfinv_) \
760
+ _(aten, exp) \
761
+ _(aten, exp2) \
762
+ _(aten, exp2_) \
763
+ _(aten, exp_) \
764
+ _(aten, expand) \
765
+ _(aten, expand_as) \
766
+ _(aten, expand_copy) \
767
+ _(aten, expm1) \
768
+ _(aten, expm1_) \
769
+ _(aten, exponential) \
770
+ _(aten, exponential_) \
771
+ _(aten, eye) \
772
+ _(aten, fake_quantize_per_channel_affine) \
773
+ _(aten, fake_quantize_per_channel_affine_cachemask) \
774
+ _(aten, fake_quantize_per_channel_affine_cachemask_backward) \
775
+ _(aten, fake_quantize_per_tensor_affine) \
776
+ _(aten, fake_quantize_per_tensor_affine_cachemask) \
777
+ _(aten, fake_quantize_per_tensor_affine_cachemask_backward) \
778
+ _(aten, fbgemm_linear_fp16_weight) \
779
+ _(aten, fbgemm_linear_fp16_weight_fp32_activation) \
780
+ _(aten, fbgemm_linear_int8_weight) \
781
+ _(aten, fbgemm_linear_int8_weight_fp32_activation) \
782
+ _(aten, fbgemm_linear_quantize_weight) \
783
+ _(aten, fbgemm_pack_gemm_matrix_fp16) \
784
+ _(aten, fbgemm_pack_quantized_matrix) \
785
+ _(aten, feature_alpha_dropout) \
786
+ _(aten, feature_alpha_dropout_) \
787
+ _(aten, feature_dropout) \
788
+ _(aten, feature_dropout_) \
789
+ _(aten, fft_fft) \
790
+ _(aten, fft_fft2) \
791
+ _(aten, fft_fftfreq) \
792
+ _(aten, fft_fftn) \
793
+ _(aten, fft_fftshift) \
794
+ _(aten, fft_hfft) \
795
+ _(aten, fft_hfft2) \
796
+ _(aten, fft_hfftn) \
797
+ _(aten, fft_ifft) \
798
+ _(aten, fft_ifft2) \
799
+ _(aten, fft_ifftn) \
800
+ _(aten, fft_ifftshift) \
801
+ _(aten, fft_ihfft) \
802
+ _(aten, fft_ihfft2) \
803
+ _(aten, fft_ihfftn) \
804
+ _(aten, fft_irfft) \
805
+ _(aten, fft_irfft2) \
806
+ _(aten, fft_irfftn) \
807
+ _(aten, fft_rfft) \
808
+ _(aten, fft_rfft2) \
809
+ _(aten, fft_rfftfreq) \
810
+ _(aten, fft_rfftn) \
811
+ _(aten, fill) \
812
+ _(aten, fill_) \
813
+ _(aten, fill_diagonal) \
814
+ _(aten, fill_diagonal_) \
815
+ _(aten, fix) \
816
+ _(aten, fix_) \
817
+ _(aten, flatten) \
818
+ _(aten, flatten_dense_tensors) \
819
+ _(aten, flip) \
820
+ _(aten, fliplr) \
821
+ _(aten, flipud) \
822
+ _(aten, float_power) \
823
+ _(aten, float_power_) \
824
+ _(aten, floor) \
825
+ _(aten, floor_) \
826
+ _(aten, floor_divide) \
827
+ _(aten, floor_divide_) \
828
+ _(aten, fmax) \
829
+ _(aten, fmin) \
830
+ _(aten, fmod) \
831
+ _(aten, fmod_) \
832
+ _(aten, frac) \
833
+ _(aten, frac_) \
834
+ _(aten, fractional_max_pool2d) \
835
+ _(aten, fractional_max_pool2d_backward) \
836
+ _(aten, fractional_max_pool3d) \
837
+ _(aten, fractional_max_pool3d_backward) \
838
+ _(aten, frexp) \
839
+ _(aten, frobenius_norm) \
840
+ _(aten, from_file) \
841
+ _(aten, full) \
842
+ _(aten, full_like) \
843
+ _(aten, fused_moving_avg_obs_fake_quant) \
844
+ _(aten, gather) \
845
+ _(aten, gather_backward) \
846
+ _(aten, gcd) \
847
+ _(aten, gcd_) \
848
+ _(aten, ge) \
849
+ _(aten, ge_) \
850
+ _(aten, gelu) \
851
+ _(aten, gelu_) \
852
+ _(aten, gelu_backward) \
853
+ _(aten, geometric) \
854
+ _(aten, geometric_) \
855
+ _(aten, geqrf) \
856
+ _(aten, ger) \
857
+ _(aten, glu) \
858
+ _(aten, glu_backward) \
859
+ _(aten, glu_backward_jvp) \
860
+ _(aten, glu_jvp) \
861
+ _(aten, gradient) \
862
+ _(aten, greater) \
863
+ _(aten, greater_) \
864
+ _(aten, greater_equal) \
865
+ _(aten, greater_equal_) \
866
+ _(aten, grid_sampler) \
867
+ _(aten, grid_sampler_2d) \
868
+ _(aten, grid_sampler_2d_backward) \
869
+ _(aten, grid_sampler_3d) \
870
+ _(aten, grid_sampler_3d_backward) \
871
+ _(aten, group_norm) \
872
+ _(aten, gru) \
873
+ _(aten, gru_cell) \
874
+ _(aten, gt) \
875
+ _(aten, gt_) \
876
+ _(aten, hamming_window) \
877
+ _(aten, hann_window) \
878
+ _(aten, hardshrink) \
879
+ _(aten, hardshrink_backward) \
880
+ _(aten, hardsigmoid) \
881
+ _(aten, hardsigmoid_) \
882
+ _(aten, hardsigmoid_backward) \
883
+ _(aten, hardswish) \
884
+ _(aten, hardswish_) \
885
+ _(aten, hardswish_backward) \
886
+ _(aten, hardtanh) \
887
+ _(aten, hardtanh_) \
888
+ _(aten, hardtanh_backward) \
889
+ _(aten, heaviside) \
890
+ _(aten, heaviside_) \
891
+ _(aten, hinge_embedding_loss) \
892
+ _(aten, histc) \
893
+ _(aten, histogram) \
894
+ _(aten, histogramdd) \
895
+ _(aten, hsplit) \
896
+ _(aten, hspmm) \
897
+ _(aten, hstack) \
898
+ _(aten, huber_loss) \
899
+ _(aten, huber_loss_backward) \
900
+ _(aten, hypot) \
901
+ _(aten, hypot_) \
902
+ _(aten, i0) \
903
+ _(aten, i0_) \
904
+ _(aten, igamma) \
905
+ _(aten, igamma_) \
906
+ _(aten, igammac) \
907
+ _(aten, igammac_) \
908
+ _(aten, im2col) \
909
+ _(aten, imag) \
910
+ _(aten, index) \
911
+ _(aten, index_add) \
912
+ _(aten, index_add_) \
913
+ _(aten, index_copy) \
914
+ _(aten, index_copy_) \
915
+ _(aten, index_fill) \
916
+ _(aten, index_fill_) \
917
+ _(aten, index_put) \
918
+ _(aten, index_put_) \
919
+ _(aten, index_reduce) \
920
+ _(aten, index_reduce_) \
921
+ _(aten, index_select) \
922
+ _(aten, index_select_backward) \
923
+ _(aten, indices) \
924
+ _(aten, indices_copy) \
925
+ _(aten, infinitely_differentiable_gelu_backward) \
926
+ _(aten, inner) \
927
+ _(aten, instance_norm) \
928
+ _(aten, int_repr) \
929
+ _(aten, inverse) \
930
+ _(aten, is_coalesced) \
931
+ _(aten, is_complex) \
932
+ _(aten, is_conj) \
933
+ _(aten, is_distributed) \
934
+ _(aten, is_floating_point) \
935
+ _(aten, is_inference) \
936
+ _(aten, is_leaf) \
937
+ _(aten, is_neg) \
938
+ _(aten, is_nonzero) \
939
+ _(aten, is_pinned) \
940
+ _(aten, is_same_size) \
941
+ _(aten, is_set_to) \
942
+ _(aten, is_signed) \
943
+ _(aten, is_vulkan_available) \
944
+ _(aten, isclose) \
945
+ _(aten, isfinite) \
946
+ _(aten, isin) \
947
+ _(aten, isinf) \
948
+ _(aten, isnan) \
949
+ _(aten, isneginf) \
950
+ _(aten, isposinf) \
951
+ _(aten, isreal) \
952
+ _(aten, istft) \
953
+ _(aten, item) \
954
+ _(aten, kaiser_window) \
955
+ _(aten, kl_div) \
956
+ _(aten, kron) \
957
+ _(aten, kthvalue) \
958
+ _(aten, l1_loss) \
959
+ _(aten, layer_norm) \
960
+ _(aten, lcm) \
961
+ _(aten, lcm_) \
962
+ _(aten, ldexp) \
963
+ _(aten, ldexp_) \
964
+ _(aten, le) \
965
+ _(aten, le_) \
966
+ _(aten, leaky_relu) \
967
+ _(aten, leaky_relu_) \
968
+ _(aten, leaky_relu_backward) \
969
+ _(aten, lerp) \
970
+ _(aten, lerp_) \
971
+ _(aten, less) \
972
+ _(aten, less_) \
973
+ _(aten, less_equal) \
974
+ _(aten, less_equal_) \
975
+ _(aten, lgamma) \
976
+ _(aten, lgamma_) \
977
+ _(aten, lift) \
978
+ _(aten, lift_fresh) \
979
+ _(aten, lift_fresh_copy) \
980
+ _(aten, linalg_cholesky) \
981
+ _(aten, linalg_cholesky_ex) \
982
+ _(aten, linalg_cond) \
983
+ _(aten, linalg_cross) \
984
+ _(aten, linalg_det) \
985
+ _(aten, linalg_diagonal) \
986
+ _(aten, linalg_eig) \
987
+ _(aten, linalg_eigh) \
988
+ _(aten, linalg_eigvals) \
989
+ _(aten, linalg_eigvalsh) \
990
+ _(aten, linalg_householder_product) \
991
+ _(aten, linalg_inv) \
992
+ _(aten, linalg_inv_ex) \
993
+ _(aten, linalg_ldl_factor) \
994
+ _(aten, linalg_ldl_factor_ex) \
995
+ _(aten, linalg_ldl_solve) \
996
+ _(aten, linalg_lstsq) \
997
+ _(aten, linalg_lu) \
998
+ _(aten, linalg_lu_factor) \
999
+ _(aten, linalg_lu_factor_ex) \
1000
+ _(aten, linalg_lu_solve) \
1001
+ _(aten, linalg_matmul) \
1002
+ _(aten, linalg_matrix_exp) \
1003
+ _(aten, linalg_matrix_norm) \
1004
+ _(aten, linalg_matrix_power) \
1005
+ _(aten, linalg_matrix_rank) \
1006
+ _(aten, linalg_multi_dot) \
1007
+ _(aten, linalg_norm) \
1008
+ _(aten, linalg_pinv) \
1009
+ _(aten, linalg_qr) \
1010
+ _(aten, linalg_slogdet) \
1011
+ _(aten, linalg_solve) \
1012
+ _(aten, linalg_solve_ex) \
1013
+ _(aten, linalg_solve_triangular) \
1014
+ _(aten, linalg_svd) \
1015
+ _(aten, linalg_svdvals) \
1016
+ _(aten, linalg_tensorinv) \
1017
+ _(aten, linalg_tensorsolve) \
1018
+ _(aten, linalg_vander) \
1019
+ _(aten, linalg_vecdot) \
1020
+ _(aten, linalg_vector_norm) \
1021
+ _(aten, linear) \
1022
+ _(aten, linear_backward) \
1023
+ _(aten, linspace) \
1024
+ _(aten, log) \
1025
+ _(aten, log10) \
1026
+ _(aten, log10_) \
1027
+ _(aten, log1p) \
1028
+ _(aten, log1p_) \
1029
+ _(aten, log2) \
1030
+ _(aten, log2_) \
1031
+ _(aten, log_) \
1032
+ _(aten, log_normal) \
1033
+ _(aten, log_normal_) \
1034
+ _(aten, log_sigmoid) \
1035
+ _(aten, log_sigmoid_backward) \
1036
+ _(aten, log_sigmoid_forward) \
1037
+ _(aten, log_softmax) \
1038
+ _(aten, logaddexp) \
1039
+ _(aten, logaddexp2) \
1040
+ _(aten, logcumsumexp) \
1041
+ _(aten, logdet) \
1042
+ _(aten, logical_and) \
1043
+ _(aten, logical_and_) \
1044
+ _(aten, logical_not) \
1045
+ _(aten, logical_not_) \
1046
+ _(aten, logical_or) \
1047
+ _(aten, logical_or_) \
1048
+ _(aten, logical_xor) \
1049
+ _(aten, logical_xor_) \
1050
+ _(aten, logit) \
1051
+ _(aten, logit_) \
1052
+ _(aten, logit_backward) \
1053
+ _(aten, logspace) \
1054
+ _(aten, logsumexp) \
1055
+ _(aten, lshift) \
1056
+ _(aten, lstm) \
1057
+ _(aten, lstm_cell) \
1058
+ _(aten, lstm_mps_backward) \
1059
+ _(aten, lt) \
1060
+ _(aten, lt_) \
1061
+ _(aten, lu_solve) \
1062
+ _(aten, lu_unpack) \
1063
+ _(aten, mH) \
1064
+ _(aten, mT) \
1065
+ _(aten, margin_ranking_loss) \
1066
+ _(aten, masked_fill) \
1067
+ _(aten, masked_fill_) \
1068
+ _(aten, masked_scatter) \
1069
+ _(aten, masked_scatter_) \
1070
+ _(aten, masked_scatter_backward) \
1071
+ _(aten, masked_select) \
1072
+ _(aten, masked_select_backward) \
1073
+ _(aten, matmul) \
1074
+ _(aten, matmul_backward) \
1075
+ _(aten, matrix_H) \
1076
+ _(aten, matrix_exp) \
1077
+ _(aten, matrix_exp_backward) \
1078
+ _(aten, matrix_power) \
1079
+ _(aten, max) \
1080
+ _(aten, max_pool1d) \
1081
+ _(aten, max_pool1d_with_indices) \
1082
+ _(aten, max_pool2d) \
1083
+ _(aten, max_pool2d_backward) \
1084
+ _(aten, max_pool2d_with_indices) \
1085
+ _(aten, max_pool2d_with_indices_backward) \
1086
+ _(aten, max_pool3d) \
1087
+ _(aten, max_pool3d_with_indices) \
1088
+ _(aten, max_pool3d_with_indices_backward) \
1089
+ _(aten, max_unpool2d) \
1090
+ _(aten, max_unpool3d) \
1091
+ _(aten, maximum) \
1092
+ _(aten, mean) \
1093
+ _(aten, median) \
1094
+ _(aten, meshgrid) \
1095
+ _(aten, min) \
1096
+ _(aten, minimum) \
1097
+ _(aten, miopen_batch_norm) \
1098
+ _(aten, miopen_batch_norm_backward) \
1099
+ _(aten, miopen_convolution) \
1100
+ _(aten, miopen_convolution_add_relu) \
1101
+ _(aten, miopen_convolution_relu) \
1102
+ _(aten, miopen_convolution_transpose) \
1103
+ _(aten, miopen_depthwise_convolution) \
1104
+ _(aten, miopen_rnn) \
1105
+ _(aten, miopen_rnn_backward) \
1106
+ _(aten, mish) \
1107
+ _(aten, mish_) \
1108
+ _(aten, mish_backward) \
1109
+ _(aten, mkldnn_adaptive_avg_pool2d) \
1110
+ _(aten, mkldnn_adaptive_avg_pool2d_backward) \
1111
+ _(aten, mkldnn_convolution) \
1112
+ _(aten, mkldnn_linear) \
1113
+ _(aten, mkldnn_linear_backward) \
1114
+ _(aten, mkldnn_linear_backward_input) \
1115
+ _(aten, mkldnn_linear_backward_weights) \
1116
+ _(aten, mkldnn_max_pool2d) \
1117
+ _(aten, mkldnn_max_pool2d_backward) \
1118
+ _(aten, mkldnn_max_pool3d) \
1119
+ _(aten, mkldnn_max_pool3d_backward) \
1120
+ _(aten, mkldnn_reorder_conv2d_weight) \
1121
+ _(aten, mkldnn_reorder_conv3d_weight) \
1122
+ _(aten, mkldnn_rnn_layer) \
1123
+ _(aten, mkldnn_rnn_layer_backward) \
1124
+ _(aten, mm) \
1125
+ _(aten, mode) \
1126
+ _(aten, moveaxis) \
1127
+ _(aten, movedim) \
1128
+ _(aten, mps_convolution_backward) \
1129
+ _(aten, mps_convolution_transpose_backward) \
1130
+ _(aten, mse_loss) \
1131
+ _(aten, mse_loss_backward) \
1132
+ _(aten, msort) \
1133
+ _(aten, mul) \
1134
+ _(aten, mul_) \
1135
+ _(aten, multi_margin_loss) \
1136
+ _(aten, multi_margin_loss_backward) \
1137
+ _(aten, multilabel_margin_loss) \
1138
+ _(aten, multilabel_margin_loss_backward) \
1139
+ _(aten, multilabel_margin_loss_forward) \
1140
+ _(aten, multinomial) \
1141
+ _(aten, multiply) \
1142
+ _(aten, multiply_) \
1143
+ _(aten, mv) \
1144
+ _(aten, mvlgamma) \
1145
+ _(aten, mvlgamma_) \
1146
+ _(aten, nan_to_num) \
1147
+ _(aten, nan_to_num_) \
1148
+ _(aten, nanmean) \
1149
+ _(aten, nanmedian) \
1150
+ _(aten, nanquantile) \
1151
+ _(aten, nansum) \
1152
+ _(aten, narrow) \
1153
+ _(aten, narrow_copy) \
1154
+ _(aten, native_batch_norm) \
1155
+ _(aten, native_batch_norm_backward) \
1156
+ _(aten, native_channel_shuffle) \
1157
+ _(aten, native_dropout) \
1158
+ _(aten, native_dropout_backward) \
1159
+ _(aten, native_group_norm) \
1160
+ _(aten, native_group_norm_backward) \
1161
+ _(aten, native_layer_norm) \
1162
+ _(aten, native_layer_norm_backward) \
1163
+ _(aten, native_norm) \
1164
+ _(aten, ne) \
1165
+ _(aten, ne_) \
1166
+ _(aten, neg) \
1167
+ _(aten, neg_) \
1168
+ _(aten, negative) \
1169
+ _(aten, negative_) \
1170
+ _(aten, nested_to_padded_tensor) \
1171
+ _(aten, new_empty) \
1172
+ _(aten, new_empty_strided) \
1173
+ _(aten, new_full) \
1174
+ _(aten, new_ones) \
1175
+ _(aten, new_zeros) \
1176
+ _(aten, nextafter) \
1177
+ _(aten, nextafter_) \
1178
+ _(aten, nll_loss) \
1179
+ _(aten, nll_loss2d) \
1180
+ _(aten, nll_loss2d_backward) \
1181
+ _(aten, nll_loss2d_forward) \
1182
+ _(aten, nll_loss_backward) \
1183
+ _(aten, nll_loss_forward) \
1184
+ _(aten, nll_loss_nd) \
1185
+ _(aten, nonzero) \
1186
+ _(aten, nonzero_numpy) \
1187
+ _(aten, nonzero_static) \
1188
+ _(aten, norm) \
1189
+ _(aten, norm_except_dim) \
1190
+ _(aten, normal) \
1191
+ _(aten, normal_) \
1192
+ _(aten, normal_functional) \
1193
+ _(aten, not_equal) \
1194
+ _(aten, not_equal_) \
1195
+ _(aten, nuclear_norm) \
1196
+ _(aten, numpy_T) \
1197
+ _(aten, one_hot) \
1198
+ _(aten, ones) \
1199
+ _(aten, ones_like) \
1200
+ _(aten, orgqr) \
1201
+ _(aten, ormqr) \
1202
+ _(aten, outer) \
1203
+ _(aten, output_nr) \
1204
+ _(aten, pad) \
1205
+ _(aten, pad_sequence) \
1206
+ _(aten, pairwise_distance) \
1207
+ _(aten, pdist) \
1208
+ _(aten, permute) \
1209
+ _(aten, permute_copy) \
1210
+ _(aten, pin_memory) \
1211
+ _(aten, pinverse) \
1212
+ _(aten, pixel_shuffle) \
1213
+ _(aten, pixel_unshuffle) \
1214
+ _(aten, poisson) \
1215
+ _(aten, poisson_nll_loss) \
1216
+ _(aten, polar) \
1217
+ _(aten, polygamma) \
1218
+ _(aten, polygamma_) \
1219
+ _(aten, positive) \
1220
+ _(aten, pow) \
1221
+ _(aten, pow_) \
1222
+ _(aten, prelu) \
1223
+ _(aten, prod) \
1224
+ _(aten, promote_types) \
1225
+ _(aten, put) \
1226
+ _(aten, put_) \
1227
+ _(aten, q_per_channel_axis) \
1228
+ _(aten, q_per_channel_scales) \
1229
+ _(aten, q_per_channel_zero_points) \
1230
+ _(aten, q_scale) \
1231
+ _(aten, q_zero_point) \
1232
+ _(aten, qr) \
1233
+ _(aten, qscheme) \
1234
+ _(aten, quantile) \
1235
+ _(aten, quantize_per_channel) \
1236
+ _(aten, quantize_per_tensor) \
1237
+ _(aten, quantize_per_tensor_dynamic) \
1238
+ _(aten, quantized_batch_norm) \
1239
+ _(aten, quantized_gru_cell) \
1240
+ _(aten, quantized_lstm_cell) \
1241
+ _(aten, quantized_max_pool1d) \
1242
+ _(aten, quantized_max_pool2d) \
1243
+ _(aten, quantized_max_pool3d) \
1244
+ _(aten, quantized_rnn_relu_cell) \
1245
+ _(aten, quantized_rnn_tanh_cell) \
1246
+ _(aten, rad2deg) \
1247
+ _(aten, rad2deg_) \
1248
+ _(aten, rand) \
1249
+ _(aten, rand_like) \
1250
+ _(aten, randint) \
1251
+ _(aten, randint_like) \
1252
+ _(aten, randn) \
1253
+ _(aten, randn_like) \
1254
+ _(aten, random) \
1255
+ _(aten, random_) \
1256
+ _(aten, randperm) \
1257
+ _(aten, range) \
1258
+ _(aten, ravel) \
1259
+ _(aten, real) \
1260
+ _(aten, reciprocal) \
1261
+ _(aten, reciprocal_) \
1262
+ _(aten, record_stream) \
1263
+ _(aten, refine_names) \
1264
+ _(aten, reflection_pad1d) \
1265
+ _(aten, reflection_pad1d_backward) \
1266
+ _(aten, reflection_pad2d) \
1267
+ _(aten, reflection_pad2d_backward) \
1268
+ _(aten, reflection_pad3d) \
1269
+ _(aten, reflection_pad3d_backward) \
1270
+ _(aten, relu) \
1271
+ _(aten, relu6) \
1272
+ _(aten, relu6_) \
1273
+ _(aten, relu_) \
1274
+ _(aten, remainder) \
1275
+ _(aten, remainder_) \
1276
+ _(aten, rename) \
1277
+ _(aten, rename_) \
1278
+ _(aten, renorm) \
1279
+ _(aten, renorm_) \
1280
+ _(aten, repeat) \
1281
+ _(aten, repeat_interleave) \
1282
+ _(aten, replication_pad1d) \
1283
+ _(aten, replication_pad1d_backward) \
1284
+ _(aten, replication_pad2d) \
1285
+ _(aten, replication_pad2d_backward) \
1286
+ _(aten, replication_pad3d) \
1287
+ _(aten, replication_pad3d_backward) \
1288
+ _(aten, requires_grad) \
1289
+ _(aten, requires_grad_) \
1290
+ _(aten, reshape) \
1291
+ _(aten, reshape_as) \
1292
+ _(aten, resize) \
1293
+ _(aten, resize_) \
1294
+ _(aten, resize_as) \
1295
+ _(aten, resize_as_) \
1296
+ _(aten, resize_as_sparse) \
1297
+ _(aten, resize_as_sparse_) \
1298
+ _(aten, resolve_conj) \
1299
+ _(aten, resolve_neg) \
1300
+ _(aten, result_type) \
1301
+ _(aten, retain_grad) \
1302
+ _(aten, retains_grad) \
1303
+ _(aten, rms_norm) \
1304
+ _(aten, rnn_relu) \
1305
+ _(aten, rnn_relu_cell) \
1306
+ _(aten, rnn_tanh) \
1307
+ _(aten, rnn_tanh_cell) \
1308
+ _(aten, roll) \
1309
+ _(aten, rot90) \
1310
+ _(aten, round) \
1311
+ _(aten, round_) \
1312
+ _(aten, row_indices) \
1313
+ _(aten, row_indices_copy) \
1314
+ _(aten, row_stack) \
1315
+ _(aten, rrelu) \
1316
+ _(aten, rrelu_) \
1317
+ _(aten, rrelu_with_noise) \
1318
+ _(aten, rrelu_with_noise_) \
1319
+ _(aten, rrelu_with_noise_backward) \
1320
+ _(aten, rrelu_with_noise_functional) \
1321
+ _(aten, rshift) \
1322
+ _(aten, rsqrt) \
1323
+ _(aten, rsqrt_) \
1324
+ _(aten, rsub) \
1325
+ _(aten, scalar_tensor) \
1326
+ _(aten, scaled_dot_product_attention) \
1327
+ _(aten, scatter) \
1328
+ _(aten, scatter_) \
1329
+ _(aten, scatter_add) \
1330
+ _(aten, scatter_add_) \
1331
+ _(aten, scatter_reduce) \
1332
+ _(aten, scatter_reduce_) \
1333
+ _(aten, searchsorted) \
1334
+ _(aten, segment_reduce) \
1335
+ _(aten, select) \
1336
+ _(aten, select_backward) \
1337
+ _(aten, select_copy) \
1338
+ _(aten, select_scatter) \
1339
+ _(aten, selu) \
1340
+ _(aten, selu_) \
1341
+ _(aten, set) \
1342
+ _(aten, set_) \
1343
+ _(aten, set_data) \
1344
+ _(aten, sgn) \
1345
+ _(aten, sgn_) \
1346
+ _(aten, sigmoid) \
1347
+ _(aten, sigmoid_) \
1348
+ _(aten, sigmoid_backward) \
1349
+ _(aten, sign) \
1350
+ _(aten, sign_) \
1351
+ _(aten, signbit) \
1352
+ _(aten, silu) \
1353
+ _(aten, silu_) \
1354
+ _(aten, silu_backward) \
1355
+ _(aten, sin) \
1356
+ _(aten, sin_) \
1357
+ _(aten, sinc) \
1358
+ _(aten, sinc_) \
1359
+ _(aten, sinh) \
1360
+ _(aten, sinh_) \
1361
+ _(aten, size) \
1362
+ _(aten, slice) \
1363
+ _(aten, slice_backward) \
1364
+ _(aten, slice_copy) \
1365
+ _(aten, slice_inverse) \
1366
+ _(aten, slice_scatter) \
1367
+ _(aten, slogdet) \
1368
+ _(aten, slow_conv3d) \
1369
+ _(aten, slow_conv3d_forward) \
1370
+ _(aten, slow_conv_dilated2d) \
1371
+ _(aten, slow_conv_dilated3d) \
1372
+ _(aten, slow_conv_transpose2d) \
1373
+ _(aten, slow_conv_transpose3d) \
1374
+ _(aten, smm) \
1375
+ _(aten, smooth_l1_loss) \
1376
+ _(aten, smooth_l1_loss_backward) \
1377
+ _(aten, soft_margin_loss) \
1378
+ _(aten, soft_margin_loss_backward) \
1379
+ _(aten, softmax) \
1380
+ _(aten, softplus) \
1381
+ _(aten, softplus_backward) \
1382
+ _(aten, softshrink) \
1383
+ _(aten, softshrink_backward) \
1384
+ _(aten, sort) \
1385
+ _(aten, sparse_bsc_tensor) \
1386
+ _(aten, sparse_bsr_tensor) \
1387
+ _(aten, sparse_compressed_tensor) \
1388
+ _(aten, sparse_coo_tensor) \
1389
+ _(aten, sparse_csc_tensor) \
1390
+ _(aten, sparse_csr_tensor) \
1391
+ _(aten, sparse_dim) \
1392
+ _(aten, sparse_mask) \
1393
+ _(aten, sparse_resize) \
1394
+ _(aten, sparse_resize_) \
1395
+ _(aten, sparse_resize_and_clear) \
1396
+ _(aten, sparse_resize_and_clear_) \
1397
+ _(aten, sparse_sampled_addmm) \
1398
+ _(aten, special_airy_ai) \
1399
+ _(aten, special_bessel_j0) \
1400
+ _(aten, special_bessel_j1) \
1401
+ _(aten, special_bessel_y0) \
1402
+ _(aten, special_bessel_y1) \
1403
+ _(aten, special_chebyshev_polynomial_t) \
1404
+ _(aten, special_chebyshev_polynomial_u) \
1405
+ _(aten, special_chebyshev_polynomial_v) \
1406
+ _(aten, special_chebyshev_polynomial_w) \
1407
+ _(aten, special_digamma) \
1408
+ _(aten, special_entr) \
1409
+ _(aten, special_erf) \
1410
+ _(aten, special_erfc) \
1411
+ _(aten, special_erfcx) \
1412
+ _(aten, special_erfinv) \
1413
+ _(aten, special_exp2) \
1414
+ _(aten, special_expit) \
1415
+ _(aten, special_expm1) \
1416
+ _(aten, special_gammainc) \
1417
+ _(aten, special_gammaincc) \
1418
+ _(aten, special_gammaln) \
1419
+ _(aten, special_hermite_polynomial_h) \
1420
+ _(aten, special_hermite_polynomial_he) \
1421
+ _(aten, special_i0) \
1422
+ _(aten, special_i0e) \
1423
+ _(aten, special_i1) \
1424
+ _(aten, special_i1e) \
1425
+ _(aten, special_laguerre_polynomial_l) \
1426
+ _(aten, special_legendre_polynomial_p) \
1427
+ _(aten, special_log1p) \
1428
+ _(aten, special_log_ndtr) \
1429
+ _(aten, special_log_softmax) \
1430
+ _(aten, special_logit) \
1431
+ _(aten, special_logsumexp) \
1432
+ _(aten, special_modified_bessel_i0) \
1433
+ _(aten, special_modified_bessel_i1) \
1434
+ _(aten, special_modified_bessel_k0) \
1435
+ _(aten, special_modified_bessel_k1) \
1436
+ _(aten, special_multigammaln) \
1437
+ _(aten, special_ndtr) \
1438
+ _(aten, special_ndtri) \
1439
+ _(aten, special_polygamma) \
1440
+ _(aten, special_psi) \
1441
+ _(aten, special_round) \
1442
+ _(aten, special_scaled_modified_bessel_k0) \
1443
+ _(aten, special_scaled_modified_bessel_k1) \
1444
+ _(aten, special_shifted_chebyshev_polynomial_t) \
1445
+ _(aten, special_shifted_chebyshev_polynomial_u) \
1446
+ _(aten, special_shifted_chebyshev_polynomial_v) \
1447
+ _(aten, special_shifted_chebyshev_polynomial_w) \
1448
+ _(aten, special_sinc) \
1449
+ _(aten, special_softmax) \
1450
+ _(aten, special_spherical_bessel_j0) \
1451
+ _(aten, special_xlog1py) \
1452
+ _(aten, special_xlogy) \
1453
+ _(aten, special_zeta) \
1454
+ _(aten, split) \
1455
+ _(aten, split_copy) \
1456
+ _(aten, split_with_sizes) \
1457
+ _(aten, split_with_sizes_copy) \
1458
+ _(aten, sqrt) \
1459
+ _(aten, sqrt_) \
1460
+ _(aten, square) \
1461
+ _(aten, square_) \
1462
+ _(aten, squeeze) \
1463
+ _(aten, squeeze_) \
1464
+ _(aten, squeeze_copy) \
1465
+ _(aten, sspaddmm) \
1466
+ _(aten, stack) \
1467
+ _(aten, std) \
1468
+ _(aten, std_mean) \
1469
+ _(aten, stft) \
1470
+ _(aten, stride) \
1471
+ _(aten, sub) \
1472
+ _(aten, sub_) \
1473
+ _(aten, subtract) \
1474
+ _(aten, subtract_) \
1475
+ _(aten, sum) \
1476
+ _(aten, sum_to_size) \
1477
+ _(aten, svd) \
1478
+ _(aten, swapaxes) \
1479
+ _(aten, swapaxes_) \
1480
+ _(aten, swapdims) \
1481
+ _(aten, swapdims_) \
1482
+ _(aten, sym_constrain_range) \
1483
+ _(aten, sym_constrain_range_for_size) \
1484
+ _(aten, sym_numel) \
1485
+ _(aten, sym_size) \
1486
+ _(aten, sym_storage_offset) \
1487
+ _(aten, sym_stride) \
1488
+ _(aten, t) \
1489
+ _(aten, t_) \
1490
+ _(aten, t_copy) \
1491
+ _(aten, take) \
1492
+ _(aten, take_along_dim) \
1493
+ _(aten, tan) \
1494
+ _(aten, tan_) \
1495
+ _(aten, tanh) \
1496
+ _(aten, tanh_) \
1497
+ _(aten, tanh_backward) \
1498
+ _(aten, tensor_split) \
1499
+ _(aten, tensordot) \
1500
+ _(aten, thnn_conv2d) \
1501
+ _(aten, threshold) \
1502
+ _(aten, threshold_) \
1503
+ _(aten, threshold_backward) \
1504
+ _(aten, tile) \
1505
+ _(aten, to) \
1506
+ _(aten, to_dense) \
1507
+ _(aten, to_dense_backward) \
1508
+ _(aten, to_mkldnn) \
1509
+ _(aten, to_mkldnn_backward) \
1510
+ _(aten, to_padded_tensor) \
1511
+ _(aten, to_sparse) \
1512
+ _(aten, to_sparse_bsc) \
1513
+ _(aten, to_sparse_bsr) \
1514
+ _(aten, to_sparse_csc) \
1515
+ _(aten, to_sparse_csr) \
1516
+ _(aten, topk) \
1517
+ _(aten, trace) \
1518
+ _(aten, trace_backward) \
1519
+ _(aten, transpose) \
1520
+ _(aten, transpose_) \
1521
+ _(aten, transpose_copy) \
1522
+ _(aten, trapezoid) \
1523
+ _(aten, trapz) \
1524
+ _(aten, triangular_solve) \
1525
+ _(aten, tril) \
1526
+ _(aten, tril_) \
1527
+ _(aten, tril_indices) \
1528
+ _(aten, triplet_margin_loss) \
1529
+ _(aten, triu) \
1530
+ _(aten, triu_) \
1531
+ _(aten, triu_indices) \
1532
+ _(aten, true_divide) \
1533
+ _(aten, true_divide_) \
1534
+ _(aten, trunc) \
1535
+ _(aten, trunc_) \
1536
+ _(aten, type_as) \
1537
+ _(aten, unbind) \
1538
+ _(aten, unbind_copy) \
1539
+ _(aten, unflatten) \
1540
+ _(aten, unflatten_dense_tensors) \
1541
+ _(aten, unfold) \
1542
+ _(aten, unfold_backward) \
1543
+ _(aten, unfold_copy) \
1544
+ _(aten, uniform) \
1545
+ _(aten, uniform_) \
1546
+ _(aten, unique_consecutive) \
1547
+ _(aten, unique_dim) \
1548
+ _(aten, unique_dim_consecutive) \
1549
+ _(aten, unsafe_chunk) \
1550
+ _(aten, unsafe_split) \
1551
+ _(aten, unsafe_split_with_sizes) \
1552
+ _(aten, unsqueeze) \
1553
+ _(aten, unsqueeze_) \
1554
+ _(aten, unsqueeze_copy) \
1555
+ _(aten, upsample_bicubic2d) \
1556
+ _(aten, upsample_bicubic2d_backward) \
1557
+ _(aten, upsample_bilinear2d) \
1558
+ _(aten, upsample_bilinear2d_backward) \
1559
+ _(aten, upsample_linear1d) \
1560
+ _(aten, upsample_linear1d_backward) \
1561
+ _(aten, upsample_nearest1d) \
1562
+ _(aten, upsample_nearest1d_backward) \
1563
+ _(aten, upsample_nearest2d) \
1564
+ _(aten, upsample_nearest2d_backward) \
1565
+ _(aten, upsample_nearest3d) \
1566
+ _(aten, upsample_nearest3d_backward) \
1567
+ _(aten, upsample_trilinear3d) \
1568
+ _(aten, upsample_trilinear3d_backward) \
1569
+ _(aten, value_selecting_reduction_backward) \
1570
+ _(aten, values) \
1571
+ _(aten, values_copy) \
1572
+ _(aten, vander) \
1573
+ _(aten, var) \
1574
+ _(aten, var_mean) \
1575
+ _(aten, vdot) \
1576
+ _(aten, view) \
1577
+ _(aten, view_as) \
1578
+ _(aten, view_as_complex) \
1579
+ _(aten, view_as_complex_copy) \
1580
+ _(aten, view_as_real) \
1581
+ _(aten, view_as_real_copy) \
1582
+ _(aten, view_copy) \
1583
+ _(aten, vsplit) \
1584
+ _(aten, vstack) \
1585
+ _(aten, where) \
1586
+ _(aten, xlogy) \
1587
+ _(aten, xlogy_) \
1588
+ _(aten, zero) \
1589
+ _(aten, zero_) \
1590
+ _(aten, zeros) \
1591
+ _(aten, zeros_like)
1592
+
1593
+ #define FORALL_ATTR_BASE_SYMBOLS(_) \
1594
+ _(attr, A) \
1595
+ _(attr, B) \
1596
+ _(attr, C) \
1597
+ _(attr, H) \
1598
+ _(attr, HxW) \
1599
+ _(attr, K) \
1600
+ _(attr, L) \
1601
+ _(attr, LD) \
1602
+ _(attr, LU) \
1603
+ _(attr, LU_data) \
1604
+ _(attr, LU_pivots) \
1605
+ _(attr, M) \
1606
+ _(attr, N) \
1607
+ _(attr, P) \
1608
+ _(attr, Q) \
1609
+ _(attr, R) \
1610
+ _(attr, S) \
1611
+ _(attr, U) \
1612
+ _(attr, UPLO) \
1613
+ _(attr, V) \
1614
+ _(attr, Vh) \
1615
+ _(attr, W) \
1616
+ _(attr, X) \
1617
+ _(attr, a) \
1618
+ _(attr, abs) \
1619
+ _(attr, accumulate) \
1620
+ _(attr, accumulate_matches) \
1621
+ _(attr, activation) \
1622
+ _(attr, addends) \
1623
+ _(attr, adjoint) \
1624
+ _(attr, alg_id) \
1625
+ _(attr, algorithm) \
1626
+ _(attr, alibi_slopes) \
1627
+ _(attr, align_corners) \
1628
+ _(attr, align_to_window) \
1629
+ _(attr, allow_tf32) \
1630
+ _(attr, alpha) \
1631
+ _(attr, amsgrad) \
1632
+ _(attr, anchor) \
1633
+ _(attr, angle) \
1634
+ _(attr, any) \
1635
+ _(attr, api_name) \
1636
+ _(attr, append) \
1637
+ _(attr, approximate) \
1638
+ _(attr, arg1) \
1639
+ _(attr, arg2) \
1640
+ _(attr, arg3) \
1641
+ _(attr, arg_out) \
1642
+ _(attr, assert_msg) \
1643
+ _(attr, assume_unique) \
1644
+ _(attr, atol) \
1645
+ _(attr, attn_bias) \
1646
+ _(attr, attn_mask) \
1647
+ _(attr, average_attn_weights) \
1648
+ _(attr, averaging_const) \
1649
+ _(attr, aweights) \
1650
+ _(attr, axis) \
1651
+ _(attr, axis0) \
1652
+ _(attr, axis1) \
1653
+ _(attr, b) \
1654
+ _(attr, b_hh) \
1655
+ _(attr, b_ih) \
1656
+ _(attr, bag_size) \
1657
+ _(attr, base) \
1658
+ _(attr, batch1) \
1659
+ _(attr, batch2) \
1660
+ _(attr, batch_dim) \
1661
+ _(attr, batch_first) \
1662
+ _(attr, batch_size) \
1663
+ _(attr, batch_sizes) \
1664
+ _(attr, benchmark) \
1665
+ _(attr, beta) \
1666
+ _(attr, beta1) \
1667
+ _(attr, beta2) \
1668
+ _(attr, bias) \
1669
+ _(attr, bias_defined) \
1670
+ _(attr, bias_g) \
1671
+ _(attr, bias_requires_grad) \
1672
+ _(attr, bias_sizes) \
1673
+ _(attr, bidirectional) \
1674
+ _(attr, bin_edges) \
1675
+ _(attr, bins) \
1676
+ _(attr, bit_width) \
1677
+ _(attr, blank) \
1678
+ _(attr, block_size) \
1679
+ _(attr, blocksize) \
1680
+ _(attr, boundaries) \
1681
+ _(attr, buffer) \
1682
+ _(attr, ccol_indices) \
1683
+ _(attr, cdim) \
1684
+ _(attr, cdist) \
1685
+ _(attr, ceil_mode) \
1686
+ _(attr, cell_state_fwd) \
1687
+ _(attr, center) \
1688
+ _(attr, ch_axis) \
1689
+ _(attr, check_errors) \
1690
+ _(attr, check_pinning) \
1691
+ _(attr, chunks) \
1692
+ _(attr, coalesced) \
1693
+ _(attr, coefficients) \
1694
+ _(attr, col) \
1695
+ _(attr, col_indices) \
1696
+ _(attr, col_offsets) \
1697
+ _(attr, col_offsets_hh) \
1698
+ _(attr, col_offsets_ih) \
1699
+ _(attr, compressed_A) \
1700
+ _(attr, compressed_idx) \
1701
+ _(attr, compressed_indices) \
1702
+ _(attr, compressed_indices_dtype) \
1703
+ _(attr, compute_log_sumexp) \
1704
+ _(attr, compute_mode) \
1705
+ _(attr, compute_uv) \
1706
+ _(attr, compute_v) \
1707
+ _(attr, condition) \
1708
+ _(attr, copy) \
1709
+ _(attr, correction) \
1710
+ _(attr, count) \
1711
+ _(attr, count_include_pad) \
1712
+ _(attr, counts) \
1713
+ _(attr, cpu_dtype) \
1714
+ _(attr, cpu_enabled) \
1715
+ _(attr, cpu_nested_shape_example) \
1716
+ _(attr, create_graph) \
1717
+ _(attr, crow_indices) \
1718
+ _(attr, cu_seqlens_k) \
1719
+ _(attr, cu_seqlens_q) \
1720
+ _(attr, cuda_dtype) \
1721
+ _(attr, cuda_enabled) \
1722
+ _(attr, cudnn_enable) \
1723
+ _(attr, cudnn_enabled) \
1724
+ _(attr, cum_seq_k) \
1725
+ _(attr, cum_seq_q) \
1726
+ _(attr, custom_mask_type) \
1727
+ _(attr, cx) \
1728
+ _(attr, cx_) \
1729
+ _(attr, cx_tmp) \
1730
+ _(attr, cy) \
1731
+ _(attr, cy_) \
1732
+ _(attr, d) \
1733
+ _(attr, dampening) \
1734
+ _(attr, data) \
1735
+ _(attr, decimals) \
1736
+ _(attr, delta) \
1737
+ _(attr, dense) \
1738
+ _(attr, dense_B) \
1739
+ _(attr, dense_dim) \
1740
+ _(attr, density) \
1741
+ _(attr, dep_token) \
1742
+ _(attr, descending) \
1743
+ _(attr, destination) \
1744
+ _(attr, deterministic) \
1745
+ _(attr, device) \
1746
+ _(attr, device_index) \
1747
+ _(attr, dgrad_glu) \
1748
+ _(attr, diagonal) \
1749
+ _(attr, diagonals) \
1750
+ _(attr, dilation) \
1751
+ _(attr, dim) \
1752
+ _(attr, dim0) \
1753
+ _(attr, dim1) \
1754
+ _(attr, dim2) \
1755
+ _(attr, dimension) \
1756
+ _(attr, dims) \
1757
+ _(attr, dims_other) \
1758
+ _(attr, dims_self) \
1759
+ _(attr, divisor_override) \
1760
+ _(attr, downscale_factor) \
1761
+ _(attr, driver) \
1762
+ _(attr, dropout) \
1763
+ _(attr, dropout_mask) \
1764
+ _(attr, dropout_p) \
1765
+ _(attr, dropout_seed) \
1766
+ _(attr, dropout_state) \
1767
+ _(attr, dst) \
1768
+ _(attr, dtype) \
1769
+ _(attr, dual) \
1770
+ _(attr, dummy) \
1771
+ _(attr, dx) \
1772
+ _(attr, edge_order) \
1773
+ _(attr, eigenvalues) \
1774
+ _(attr, eigenvectors) \
1775
+ _(attr, eigvals) \
1776
+ _(attr, eigvecs) \
1777
+ _(attr, element) \
1778
+ _(attr, elements) \
1779
+ _(attr, ellipsis_idx) \
1780
+ _(attr, embed_dim) \
1781
+ _(attr, enable_gqa) \
1782
+ _(attr, end) \
1783
+ _(attr, end_dim) \
1784
+ _(attr, eps) \
1785
+ _(attr, epsilon) \
1786
+ _(attr, equal_nan) \
1787
+ _(attr, equation) \
1788
+ _(attr, exp_avg_sqs) \
1789
+ _(attr, exp_avgs) \
1790
+ _(attr, expand1) \
1791
+ _(attr, expand2) \
1792
+ _(attr, expand3) \
1793
+ _(attr, exponent) \
1794
+ _(attr, exponential_average_factor) \
1795
+ _(attr, fake_quant_enabled) \
1796
+ _(attr, fake_quant_on) \
1797
+ _(attr, ffn_bias_1) \
1798
+ _(attr, ffn_bias_2) \
1799
+ _(attr, ffn_weight_1) \
1800
+ _(attr, ffn_weight_2) \
1801
+ _(attr, filename) \
1802
+ _(attr, fill) \
1803
+ _(attr, fill_value) \
1804
+ _(attr, flat) \
1805
+ _(attr, forward) \
1806
+ _(attr, found_inf) \
1807
+ _(attr, from) \
1808
+ _(attr, from_) \
1809
+ _(attr, full) \
1810
+ _(attr, full_matrices) \
1811
+ _(attr, fuse_transform_0213) \
1812
+ _(attr, fweights) \
1813
+ _(attr, g) \
1814
+ _(attr, gO) \
1815
+ _(attr, generator) \
1816
+ _(attr, ggI) \
1817
+ _(attr, ggW) \
1818
+ _(attr, ggb) \
1819
+ _(attr, glu) \
1820
+ _(attr, grad) \
1821
+ _(attr, grad_bias) \
1822
+ _(attr, grad_cy) \
1823
+ _(attr, grad_factor) \
1824
+ _(attr, grad_glu) \
1825
+ _(attr, grad_hy) \
1826
+ _(attr, grad_in) \
1827
+ _(attr, grad_input) \
1828
+ _(attr, grad_input_mask) \
1829
+ _(attr, grad_out) \
1830
+ _(attr, grad_out_) \
1831
+ _(attr, grad_output) \
1832
+ _(attr, grad_scale) \
1833
+ _(attr, grad_w) \
1834
+ _(attr, grad_weight) \
1835
+ _(attr, grad_x) \
1836
+ _(attr, grad_y) \
1837
+ _(attr, gradient) \
1838
+ _(attr, grads) \
1839
+ _(attr, grid) \
1840
+ _(attr, group) \
1841
+ _(attr, groups) \
1842
+ _(attr, growth_interval) \
1843
+ _(attr, growth_tracker) \
1844
+ _(attr, half_to_float) \
1845
+ _(attr, has_bias) \
1846
+ _(attr, has_biases) \
1847
+ _(attr, hermitian) \
1848
+ _(attr, hidden_bias) \
1849
+ _(attr, hidden_gates) \
1850
+ _(attr, hidden_size) \
1851
+ _(attr, high) \
1852
+ _(attr, hist) \
1853
+ _(attr, hop_length) \
1854
+ _(attr, hx) \
1855
+ _(attr, hx_) \
1856
+ _(attr, hy_) \
1857
+ _(attr, i1) \
1858
+ _(attr, i2) \
1859
+ _(attr, i3) \
1860
+ _(attr, ignore_index) \
1861
+ _(attr, imag) \
1862
+ _(attr, impl_index) \
1863
+ _(attr, implicit) \
1864
+ _(attr, in_features) \
1865
+ _(attr, include_last_offset) \
1866
+ _(attr, include_self) \
1867
+ _(attr, increasing) \
1868
+ _(attr, ind) \
1869
+ _(attr, index) \
1870
+ _(attr, index_dtype) \
1871
+ _(attr, indexing) \
1872
+ _(attr, indices) \
1873
+ _(attr, info) \
1874
+ _(attr, initial) \
1875
+ _(attr, innerKTiles) \
1876
+ _(attr, inp) \
1877
+ _(attr, input) \
1878
+ _(attr, input1) \
1879
+ _(attr, input2) \
1880
+ _(attr, input3) \
1881
+ _(attr, input_bias) \
1882
+ _(attr, input_dtype) \
1883
+ _(attr, input_g) \
1884
+ _(attr, input_gates) \
1885
+ _(attr, input_lengths) \
1886
+ _(attr, input_scale) \
1887
+ _(attr, input_size) \
1888
+ _(attr, input_sizes) \
1889
+ _(attr, input_zero_point) \
1890
+ _(attr, inputs) \
1891
+ _(attr, interpolation) \
1892
+ _(attr, interpolation_mode) \
1893
+ _(attr, inv_scale) \
1894
+ _(attr, inverse) \
1895
+ _(attr, invert) \
1896
+ _(attr, invstd) \
1897
+ _(attr, is_causal) \
1898
+ _(attr, is_coalesced) \
1899
+ _(attr, is_crow) \
1900
+ _(attr, is_first_step) \
1901
+ _(attr, is_matrix) \
1902
+ _(attr, is_result) \
1903
+ _(attr, is_target) \
1904
+ _(attr, k) \
1905
+ _(attr, keepdim) \
1906
+ _(attr, kernel_size) \
1907
+ _(attr, key) \
1908
+ _(attr, label_smoothing) \
1909
+ _(attr, lambd) \
1910
+ _(attr, largest) \
1911
+ _(attr, last_dim_size) \
1912
+ _(attr, layersOutputs) \
1913
+ _(attr, layout) \
1914
+ _(attr, left) \
1915
+ _(attr, length) \
1916
+ _(attr, lengths) \
1917
+ _(attr, level) \
1918
+ _(attr, like) \
1919
+ _(attr, list) \
1920
+ _(attr, log_alpha) \
1921
+ _(attr, log_input) \
1922
+ _(attr, log_probs) \
1923
+ _(attr, log_target) \
1924
+ _(attr, logabsdet) \
1925
+ _(attr, logsumexp) \
1926
+ _(attr, low) \
1927
+ _(attr, lower) \
1928
+ _(attr, lr) \
1929
+ _(attr, lr_decay) \
1930
+ _(attr, ltm) \
1931
+ _(attr, m) \
1932
+ _(attr, mantissa) \
1933
+ _(attr, margin) \
1934
+ _(attr, mask) \
1935
+ _(attr, mask_check) \
1936
+ _(attr, mask_type) \
1937
+ _(attr, masked_grad) \
1938
+ _(attr, mat) \
1939
+ _(attr, mat1) \
1940
+ _(attr, mat1_meta) \
1941
+ _(attr, mat2) \
1942
+ _(attr, matrices) \
1943
+ _(attr, max) \
1944
+ _(attr, max_exp_avg_sqs) \
1945
+ _(attr, max_k) \
1946
+ _(attr, max_lengths) \
1947
+ _(attr, max_norm) \
1948
+ _(attr, max_q) \
1949
+ _(attr, max_seqlen) \
1950
+ _(attr, max_seqlen_k) \
1951
+ _(attr, max_seqlen_q) \
1952
+ _(attr, max_size) \
1953
+ _(attr, max_val) \
1954
+ _(attr, max_values) \
1955
+ _(attr, maximize) \
1956
+ _(attr, maximum_indices) \
1957
+ _(attr, maxnorm) \
1958
+ _(attr, mean) \
1959
+ _(attr, median) \
1960
+ _(attr, memory_format) \
1961
+ _(attr, meta) \
1962
+ _(attr, min) \
1963
+ _(attr, min_indices) \
1964
+ _(attr, min_seqlen) \
1965
+ _(attr, min_val) \
1966
+ _(attr, minlength) \
1967
+ _(attr, mode) \
1968
+ _(attr, momentum) \
1969
+ _(attr, momentum_buffer_list) \
1970
+ _(attr, n) \
1971
+ _(attr, n_bins) \
1972
+ _(attr, n_fft) \
1973
+ _(attr, names) \
1974
+ _(attr, nan) \
1975
+ _(attr, need_weights) \
1976
+ _(attr, neg_log_likelihood) \
1977
+ _(attr, negative) \
1978
+ _(attr, negative_slope) \
1979
+ _(attr, neginf) \
1980
+ _(attr, nested_size) \
1981
+ _(attr, nested_strides) \
1982
+ _(attr, nesterov) \
1983
+ _(attr, new_data) \
1984
+ _(attr, nnz) \
1985
+ _(attr, noise) \
1986
+ _(attr, non_blocking) \
1987
+ _(attr, norm) \
1988
+ _(attr, norm_bias_1) \
1989
+ _(attr, norm_bias_2) \
1990
+ _(attr, norm_first) \
1991
+ _(attr, norm_type) \
1992
+ _(attr, norm_weight_1) \
1993
+ _(attr, norm_weight_2) \
1994
+ _(attr, normalization) \
1995
+ _(attr, normalized) \
1996
+ _(attr, normalized_shape) \
1997
+ _(attr, normalized_shape_ndim) \
1998
+ _(attr, nt_example) \
1999
+ _(attr, num_chunks) \
2000
+ _(attr, num_classes) \
2001
+ _(attr, num_generated) \
2002
+ _(attr, num_groups) \
2003
+ _(attr, num_head) \
2004
+ _(attr, num_heads) \
2005
+ _(attr, num_layers) \
2006
+ _(attr, num_parallel) \
2007
+ _(attr, num_samples) \
2008
+ _(attr, num_splits_key) \
2009
+ _(attr, num_weights) \
2010
+ _(attr, numel) \
2011
+ _(attr, observer_on) \
2012
+ _(attr, offs) \
2013
+ _(attr, offset) \
2014
+ _(attr, offset2bag) \
2015
+ _(attr, offsets) \
2016
+ _(attr, onesided) \
2017
+ _(attr, ord) \
2018
+ _(attr, order) \
2019
+ _(attr, other) \
2020
+ _(attr, out) \
2021
+ _(attr, out0) \
2022
+ _(attr, out1) \
2023
+ _(attr, out2) \
2024
+ _(attr, out3) \
2025
+ _(attr, out4) \
2026
+ _(attr, out5) \
2027
+ _(attr, out6) \
2028
+ _(attr, out_channel) \
2029
+ _(attr, out_dim) \
2030
+ _(attr, out_dtype) \
2031
+ _(attr, out_features) \
2032
+ _(attr, out_int32) \
2033
+ _(attr, outdim) \
2034
+ _(attr, output) \
2035
+ _(attr, output_mask) \
2036
+ _(attr, output_padding) \
2037
+ _(attr, output_scale) \
2038
+ _(attr, output_size) \
2039
+ _(attr, output_zero_point) \
2040
+ _(attr, p) \
2041
+ _(attr, packed) \
2042
+ _(attr, packed_hh) \
2043
+ _(attr, packed_ih) \
2044
+ _(attr, packed_weight) \
2045
+ _(attr, packed_weights) \
2046
+ _(attr, pad) \
2047
+ _(attr, pad_mode) \
2048
+ _(attr, padded) \
2049
+ _(attr, padding) \
2050
+ _(attr, padding_idx) \
2051
+ _(attr, padding_mode) \
2052
+ _(attr, padding_side) \
2053
+ _(attr, padding_value) \
2054
+ _(attr, params) \
2055
+ _(attr, path) \
2056
+ _(attr, pdist) \
2057
+ _(attr, per_row_fake_quant) \
2058
+ _(attr, per_sample_weights) \
2059
+ _(attr, periodic) \
2060
+ _(attr, philox_offset) \
2061
+ _(attr, philox_seed) \
2062
+ _(attr, physical_layout) \
2063
+ _(attr, pin_memory) \
2064
+ _(attr, pivot) \
2065
+ _(attr, pivots) \
2066
+ _(attr, plain_idx) \
2067
+ _(attr, plain_indices) \
2068
+ _(attr, pos_weight) \
2069
+ _(attr, posinf) \
2070
+ _(attr, positive) \
2071
+ _(attr, pow) \
2072
+ _(attr, prepend) \
2073
+ _(attr, primal) \
2074
+ _(attr, prob) \
2075
+ _(attr, proj_bias) \
2076
+ _(attr, proj_size) \
2077
+ _(attr, proj_weight) \
2078
+ _(attr, q) \
2079
+ _(attr, qGroupSize) \
2080
+ _(attr, qScale) \
2081
+ _(attr, qScaleAndZeros) \
2082
+ _(attr, qZeros) \
2083
+ _(attr, qkv) \
2084
+ _(attr, qkv_bias) \
2085
+ _(attr, qkv_weight) \
2086
+ _(attr, qtensor) \
2087
+ _(attr, quant_max) \
2088
+ _(attr, quant_min) \
2089
+ _(attr, quasi) \
2090
+ _(attr, query) \
2091
+ _(attr, r) \
2092
+ _(attr, ragged_idx) \
2093
+ _(attr, random_samples) \
2094
+ _(attr, range) \
2095
+ _(attr, rank) \
2096
+ _(attr, ratio) \
2097
+ _(attr, rcond) \
2098
+ _(attr, real) \
2099
+ _(attr, reduce) \
2100
+ _(attr, reduce_range) \
2101
+ _(attr, reduction) \
2102
+ _(attr, repeats) \
2103
+ _(attr, replacement) \
2104
+ _(attr, requires_grad) \
2105
+ _(attr, reserve) \
2106
+ _(attr, reserveSpace) \
2107
+ _(attr, reservedSpace) \
2108
+ _(attr, residuals) \
2109
+ _(attr, result) \
2110
+ _(attr, retain_graph) \
2111
+ _(attr, return_complex) \
2112
+ _(attr, return_counts) \
2113
+ _(attr, return_debug_mask) \
2114
+ _(attr, return_inverse) \
2115
+ _(attr, reverse) \
2116
+ _(attr, right) \
2117
+ _(attr, rng_state) \
2118
+ _(attr, rounding_mode) \
2119
+ _(attr, row) \
2120
+ _(attr, row_indices) \
2121
+ _(attr, rstd) \
2122
+ _(attr, rtol) \
2123
+ _(attr, running_max) \
2124
+ _(attr, running_mean) \
2125
+ _(attr, running_min) \
2126
+ _(attr, running_var) \
2127
+ _(attr, s) \
2128
+ _(attr, save_invstd) \
2129
+ _(attr, save_mean) \
2130
+ _(attr, save_var) \
2131
+ _(attr, save_var_transform) \
2132
+ _(attr, saved_g) \
2133
+ _(attr, saved_norms) \
2134
+ _(attr, saved_v) \
2135
+ _(attr, scalar) \
2136
+ _(attr, scalar1) \
2137
+ _(attr, scalar2) \
2138
+ _(attr, scalars) \
2139
+ _(attr, scale) \
2140
+ _(attr, scale_a) \
2141
+ _(attr, scale_b) \
2142
+ _(attr, scale_backoff_factor) \
2143
+ _(attr, scale_factors) \
2144
+ _(attr, scale_grad_by_freq) \
2145
+ _(attr, scale_growth_factor) \
2146
+ _(attr, scale_hh) \
2147
+ _(attr, scale_ih) \
2148
+ _(attr, scale_result) \
2149
+ _(attr, scales) \
2150
+ _(attr, scales_d) \
2151
+ _(attr, scales_h) \
2152
+ _(attr, scales_w) \
2153
+ _(attr, scales_zeros) \
2154
+ _(attr, sections) \
2155
+ _(attr, seed) \
2156
+ _(attr, self) \
2157
+ _(attr, self_is_result) \
2158
+ _(attr, self_num_batch_dims) \
2159
+ _(attr, self_or_result) \
2160
+ _(attr, self_sizes) \
2161
+ _(attr, seqlen_k) \
2162
+ _(attr, sequences) \
2163
+ _(attr, seqused_k) \
2164
+ _(attr, shape) \
2165
+ _(attr, shared) \
2166
+ _(attr, shared_storage_dqdkdv) \
2167
+ _(attr, shifts) \
2168
+ _(attr, side) \
2169
+ _(attr, sigma) \
2170
+ _(attr, sign) \
2171
+ _(attr, singular_values) \
2172
+ _(attr, size) \
2173
+ _(attr, sizes) \
2174
+ _(attr, skip_first) \
2175
+ _(attr, sobolstate) \
2176
+ _(attr, solution) \
2177
+ _(attr, some) \
2178
+ _(attr, sorted) \
2179
+ _(attr, sorted_sequence) \
2180
+ _(attr, sorter) \
2181
+ _(attr, source) \
2182
+ _(attr, spacing) \
2183
+ _(attr, sparse) \
2184
+ _(attr, sparse_dim) \
2185
+ _(attr, sparse_grad) \
2186
+ _(attr, split_k) \
2187
+ _(attr, split_k_mode) \
2188
+ _(attr, split_size) \
2189
+ _(attr, split_sizes) \
2190
+ _(attr, src) \
2191
+ _(attr, stable) \
2192
+ _(attr, start) \
2193
+ _(attr, start_dim) \
2194
+ _(attr, state_steps) \
2195
+ _(attr, state_sums) \
2196
+ _(attr, std) \
2197
+ _(attr, step) \
2198
+ _(attr, steps) \
2199
+ _(attr, storage_offset) \
2200
+ _(attr, stride) \
2201
+ _(attr, sum_S) \
2202
+ _(attr, sum_dy) \
2203
+ _(attr, sum_dy_xmu) \
2204
+ _(attr, sumdim) \
2205
+ _(attr, swap) \
2206
+ _(attr, symmetric_quant) \
2207
+ _(attr, t) \
2208
+ _(attr, tangent) \
2209
+ _(attr, target) \
2210
+ _(attr, target_lengths) \
2211
+ _(attr, targets) \
2212
+ _(attr, tau) \
2213
+ _(attr, tensor) \
2214
+ _(attr, tensor1) \
2215
+ _(attr, tensor2) \
2216
+ _(attr, tensor_indices_or_sections) \
2217
+ _(attr, tensors) \
2218
+ _(attr, tensors1) \
2219
+ _(attr, test_element) \
2220
+ _(attr, test_elements) \
2221
+ _(attr, the_template) \
2222
+ _(attr, theta) \
2223
+ _(attr, thread_masks) \
2224
+ _(attr, threshold) \
2225
+ _(attr, to) \
2226
+ _(attr, tol) \
2227
+ _(attr, total) \
2228
+ _(attr, total_L) \
2229
+ _(attr, total_length) \
2230
+ _(attr, total_weight) \
2231
+ _(attr, train) \
2232
+ _(attr, training) \
2233
+ _(attr, transpose) \
2234
+ _(attr, transpose_result) \
2235
+ _(attr, transposed) \
2236
+ _(attr, type1) \
2237
+ _(attr, type2) \
2238
+ _(attr, unbiased) \
2239
+ _(attr, unitriangular) \
2240
+ _(attr, unpack_data) \
2241
+ _(attr, unpack_pivots) \
2242
+ _(attr, unroll_dim) \
2243
+ _(attr, unsafe) \
2244
+ _(attr, unused) \
2245
+ _(attr, update) \
2246
+ _(attr, upper) \
2247
+ _(attr, upscale_factor) \
2248
+ _(attr, use_cutlass) \
2249
+ _(attr, use_fast_accum) \
2250
+ _(attr, use_gelu) \
2251
+ _(attr, use_input_stats) \
2252
+ _(attr, v) \
2253
+ _(attr, value) \
2254
+ _(attr, values) \
2255
+ _(attr, var) \
2256
+ _(attr, vec) \
2257
+ _(attr, vec1) \
2258
+ _(attr, vec2) \
2259
+ _(attr, w_hh) \
2260
+ _(attr, w_ih) \
2261
+ _(attr, weight) \
2262
+ _(attr, weight0) \
2263
+ _(attr, weight1) \
2264
+ _(attr, weight2) \
2265
+ _(attr, weight3) \
2266
+ _(attr, weight4) \
2267
+ _(attr, weight_arr) \
2268
+ _(attr, weight_buf) \
2269
+ _(attr, weight_decay) \
2270
+ _(attr, weight_g) \
2271
+ _(attr, weight_scale) \
2272
+ _(attr, weight_stride0) \
2273
+ _(attr, weight_zero_point) \
2274
+ _(attr, weights) \
2275
+ _(attr, win_length) \
2276
+ _(attr, window) \
2277
+ _(attr, window_length) \
2278
+ _(attr, window_size) \
2279
+ _(attr, window_size_left) \
2280
+ _(attr, window_size_right) \
2281
+ _(attr, with_replacement) \
2282
+ _(attr, workspace) \
2283
+ _(attr, wrap) \
2284
+ _(attr, x) \
2285
+ _(attr, x1) \
2286
+ _(attr, x2) \
2287
+ _(attr, y) \
2288
+ _(attr, z) \
2289
+ _(attr, z_state) \
2290
+ _(attr, zero_infinity) \
2291
+ _(attr, zero_point) \
2292
+ _(attr, zero_point_hh) \
2293
+ _(attr, zero_point_ih) \
2294
+ _(attr, zero_points)
.venv/lib/python3.12/site-packages/torch/include/ATen/core/blob.h ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <type_traits>
4
+
5
+ #include <c10/util/intrusive_ptr.h>
6
+ #include <c10/util/typeid.h>
7
+ #include <c10/macros/Macros.h>
8
+
9
+ namespace caffe2 {
10
+
11
+ class Tensor;
12
+
13
+ /**
14
+ * @brief Blob is a general container that hosts a typed pointer.
15
+ *
16
+ * A Blob hosts a pointer as well as its type, and takes charge of deleting it
17
+ * properly when the blob is deallocated or re-allocated with a new type. A blob
18
+ * could contain anything, although the most common case is to contain a Tensor.
19
+ */
20
+ class TORCH_API Blob final : public c10::intrusive_ptr_target {
21
+ public:
22
+ /**
23
+ * Initializes an empty Blob.
24
+ */
25
+ Blob() noexcept = default;
26
+ ~Blob() override {
27
+ Reset();
28
+ }
29
+
30
+ Blob(Blob&& other) noexcept : Blob() {
31
+ swap(other);
32
+ }
33
+
34
+ Blob& operator=(Blob&& other) noexcept {
35
+ Blob(std::move(other)).swap(*this);
36
+ return *this;
37
+ }
38
+
39
+ /**
40
+ * Checks if the content stored in the blob is of type T.
41
+ */
42
+ template <class T>
43
+ bool IsType() const noexcept {
44
+ return meta_.Match<T>();
45
+ }
46
+
47
+ /**
48
+ * Returns the meta info of the blob.
49
+ */
50
+ const TypeMeta meta() const noexcept {
51
+ return meta_;
52
+ }
53
+
54
+ /**
55
+ * Returns a printable typename of the blob.
56
+ */
57
+ std::string_view TypeName() const noexcept {
58
+ return meta_.name();
59
+ }
60
+
61
+ /**
62
+ * @brief Gets the const reference of the stored object. The code checks if
63
+ * the stored object is of the desired type.
64
+ */
65
+ // TODO(jerryzh): add a Get(c10::DeviceType) function?
66
+ template <class T>
67
+ const T& Get() const {
68
+ TORCH_INTERNAL_ASSERT(
69
+ IsType<T>(),
70
+ "wrong type for the Blob instance. Blob contains ",
71
+ meta_.name(),
72
+ " while caller expects ",
73
+ TypeMeta::TypeName<T>());
74
+ // TODO: after we add Get<Tensor>(c10::DeviceType)
75
+ // and changed all the callsites, we can add
76
+ // a static assert here to enforce T != Tensor
77
+ return *static_cast<const T*>(pointer_);
78
+ }
79
+
80
+ const void* GetRaw() const noexcept {
81
+ return pointer_;
82
+ }
83
+ void* GetRaw() noexcept {
84
+ return pointer_;
85
+ }
86
+
87
+ /**
88
+ * @brief Gets a mutable pointer to the stored object.
89
+ *
90
+ * If the current object is not of the right type, a new object is created
91
+ * and the old object is freed. Note that type T should have a default
92
+ * constructor. Otherwise, create the object yourself first, and use
93
+ * Reset().
94
+ */
95
+ template <class T>
96
+ T* GetMutable() {
97
+ static_assert(
98
+ std::is_default_constructible_v<T>,
99
+ "GetMutable can't be called with non-default-constructible types. "
100
+ "Try using specialized methods");
101
+ if (IsType<T>()) {
102
+ return static_cast<T*>(pointer_);
103
+ } else {
104
+ // TODO Re-enable logging
105
+ // VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<T>();
106
+ return Reset<T>(new T());
107
+ }
108
+ }
109
+
110
+ template <class T>
111
+ T* GetMutableOrNull() {
112
+ if (IsType<T>()) {
113
+ return static_cast<T*>(pointer_);
114
+ } else {
115
+ return nullptr;
116
+ }
117
+ }
118
+
119
+ /**
120
+ * Sets the underlying object to the allocated one. The Blob then takes over
121
+ * the ownership of the passed in pointer. If there is already an object in
122
+ * the Blob, the old object is freed.
123
+ *
124
+ * This is used when the underlying class T does not have a default ctor, or
125
+ * complex initializations needs to be done outside the blob.
126
+ */
127
+ template <class T>
128
+ T* Reset(T* allocated) {
129
+ free_();
130
+ meta_ = TypeMeta::Make<T>();
131
+ pointer_ = static_cast<void*>(allocated);
132
+ has_ownership_ = true;
133
+ return allocated;
134
+ }
135
+
136
+ /**
137
+ * Sets the underlying object to the allocated one, but does not take over
138
+ * the ownership of the passed in pointer. If there is already an object in
139
+ * the Blob, the old object is freed.
140
+ *
141
+ * Unlike Reset, this does not take over the ownership of the pointer and the
142
+ * caller is responsible for making sure that the lifetime of the allocated
143
+ * blob outlasts the lifetime of any access to this blob, until another Reset
144
+ * call is made or the blob is destructed.
145
+ */
146
+ template <class T>
147
+ std::remove_const_t<T>* ShareExternal(
148
+ std::remove_const_t<T>* allocated) {
149
+ return static_cast<T*>(ShareExternal(
150
+ static_cast<void*>(allocated),
151
+ TypeMeta::Make<std::remove_const_t<T>>()));
152
+ }
153
+
154
+ void* ShareExternal(void* allocated, const TypeMeta meta) {
155
+ free_();
156
+ meta_ = meta;
157
+ pointer_ = allocated;
158
+ has_ownership_ = false;
159
+ return allocated;
160
+ }
161
+
162
+ /**
163
+ * Resets the Blob to an empty one.
164
+ */
165
+ void Reset() {
166
+ free_();
167
+ pointer_ = nullptr;
168
+ meta_ = TypeMeta();
169
+ has_ownership_ = false;
170
+ }
171
+
172
+ /**
173
+ * @brief Swaps the underlying storage of two blobs.
174
+ */
175
+ void swap(Blob& rhs) noexcept {
176
+ using std::swap;
177
+ swap(meta_, rhs.meta_);
178
+ swap(pointer_, rhs.pointer_);
179
+ swap(has_ownership_, rhs.has_ownership_);
180
+ }
181
+
182
+ private:
183
+ void free_() {
184
+ if (has_ownership_ && pointer_ != nullptr) {
185
+ (*meta_.deleteFn())(pointer_);
186
+ }
187
+ }
188
+
189
+ TypeMeta meta_;
190
+ void* pointer_{nullptr};
191
+ bool has_ownership_{false};
192
+
193
+ C10_DISABLE_COPY_AND_ASSIGN(Blob);
194
+ };
195
+
196
+ inline void swap(Blob& lhs, Blob& rhs) noexcept {
197
+ lhs.swap(rhs);
198
+ }
199
+
200
+ inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
201
+ return out << "Blob[" << v.TypeName() << "]";
202
+ }
203
+
204
+ } // namespace caffe2
.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/boxing/OperatorKernel.h>
4
+ #include <c10/core/DispatchKeySet.h>
5
+ #include <c10/util/intrusive_ptr.h>
6
+
7
+ namespace c10 {
8
+
9
+ struct IValue;
10
+ using Stack = std::vector<IValue>;
11
+
12
+ class OperatorHandle;
13
+ class KernelFunction;
14
+
15
+ // This kernel implements the behavior of falling through to the next available
16
+ // registered dispatch key. The implementation of this function is FAST; it is
17
+ // no overhead to fallthrough to the next key. See cpp file for some more
18
+ // implementation notes; notably, this does NOT actually go through the
19
+ // boxing/unboxing codepath.
20
+ TORCH_API void fallthrough_kernel(
21
+ OperatorKernel*,
22
+ const OperatorHandle&,
23
+ DispatchKeySet,
24
+ Stack*);
25
+
26
+ // Note [Ambiguity in AutogradOther kernel]
27
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
28
+ // This error-reporting kernel is registered to the AutogradOther entry in the
29
+ // dispatch table when there is both a CompositeImplicitAutograd kernel and a
30
+ // backend kernel for ANY backend that maps to AutogradOther. To see why
31
+ // this is necessary in the AutogradOther case, it's helpful to first see
32
+ // why everything works out fine for a backend that has a reserved Autograd
33
+ // entry (see rule 2.2 in [Note] DispatchTable computation):
34
+ //
35
+ // CPU AutogradCPU
36
+ // reg? registers with...
37
+ // -------------------------------------------------
38
+ // y Autograd registration takes precedence
39
+ // over CompositeImplicitAutograd.
40
+ // This is good, because the CPU specific backend
41
+ // implementation is more specialized and typically better;
42
+ // if we used the composite, we would bypass it.
43
+ // (NB: the Autograd key is guaranteed to exist because
44
+ // the autograd codegen requires it!)
45
+ //
46
+ // n CompositeImplicitAutograd takes precedence.
47
+ // This is also good, because the Autograd
48
+ // registration (if it exists) would try to redispatch
49
+ // to the (non-existent) CPU implementation; by
50
+ // using the composite, we ensure the operator
51
+ // actually works.
52
+ //
53
+ // As you can see, when we have a specific Autograd key (AutogradCPU), we can
54
+ // decide whether or not to use the CompositeImplicitAutograd kernel or the
55
+ // Autograd kernel based on whether or not the backend kernel exists.
56
+ //
57
+ // However, for AutogradOther (which is the catchall autograd kernel for
58
+ // everything that doesn't have a specific Autograd key), we can't do this
59
+ // trick because there isn't any unique backend to peek at to disambiguate;
60
+ // if there are some backends that have implementations they prefer Autograd,
61
+ // but unimplemented backends would prefer CompositeImplicitAutograd. Rather
62
+ // than arbitrarily pick one or the other, we just register a kernel that raises
63
+ // an error and let the user decide how to proceed.
64
+ TORCH_API void ambiguous_autogradother_kernel(
65
+ OperatorKernel*,
66
+ const OperatorHandle&,
67
+ DispatchKeySet,
68
+ Stack*);
69
+
70
+ // Note [named_not_supported_kernel]
71
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
72
+ // This kernel implements reporting an error message saying that named tensor is
73
+ // not supported. This kernel doesn't rely on the Stack, and so it is special
74
+ // cased in the dispatcher to be triggered before we attempt boxing (so we can
75
+ // give a good error message in cases when boxing is not supported). When
76
+ // boxing is universally supported this can be removed.
77
+ [[noreturn]] TORCH_API void named_not_supported_kernel(
78
+ OperatorKernel*,
79
+ const OperatorHandle&,
80
+ DispatchKeySet,
81
+ Stack*);
82
+
83
+ /**
84
+ * BoxedKernel is similar to a std::function storing a boxed kernel.
85
+ */
86
+ class TORCH_API BoxedKernel final {
87
+ public:
88
+ // This is how boxed kernels are actually stored
89
+ //
90
+ // Note [Plumbing Keys Through The Dispatcher]
91
+ // Benchmarks have shown that it is expensive for the dispatcher to read from
92
+ // thread-local storage (TLS) upon every dispatch call into order to compute
93
+ // which kernel to dispatch to.
94
+ //
95
+ // To mitigate this, we've updated the calling convention inside the
96
+ // dispatcher to expect every kernel that it stores to have a first argument
97
+ // of type DispatchKeySet.
98
+ //
99
+ // What are the invariants of the DispatchKeySet when it gets passed to a
100
+ // kernel?
101
+ // - All keys to the left of the current dispatch key have been masked out.
102
+ // (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the
103
+ // highest bit to be DispatchKey::Tracer)
104
+ // - All other keys that dispatcher normally would have computed through TLS +
105
+ // global state + op arguments
106
+ // are still in the set.
107
+ //
108
+ // Kernels can then opt into using this keyset to save the dispatcher from
109
+ // doing repeated work during redispatches: recalculating the highest-priority
110
+ // dispatch key, which involves reading from TLS. Instead, the kernels that
111
+ // opt in will calculate an updated DispatchKeySet directly from the old one,
112
+ // and pass the updated set directly into the dispatcher upon redispatching.
113
+ //
114
+ // This is an opt-in mechanism: Kernels can automatically opt in by setting
115
+ // the first argument in their signature to be of type DispatchKeySet. See the
116
+ // kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for
117
+ // examples.
118
+ //
119
+ // The mechanism for optionally passing that DispatchKeySet into the kernel
120
+ // lives in make_boxed_from_unboxed_functor.h. See Note [Plumbing Keys Through
121
+ // The Dispatcher 2] for details.
122
+ using InternalBoxedKernelFunction =
123
+ void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
124
+ // This is the public API for how boxed kernels are defined
125
+ using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
126
+ using BoxedKernelFunction_withDispatchKeys =
127
+ void(const OperatorHandle&, DispatchKeySet, Stack*);
128
+
129
+ BoxedKernel();
130
+
131
+ // Fast path for dispatch to allow not touching the boxed kernel in
132
+ // the common case where unboxed is available.
133
+ bool isValid() const;
134
+ bool isFallthrough() const;
135
+
136
+ /**
137
+ * Call the function with boxed arguments.
138
+ */
139
+ void callBoxed(
140
+ const OperatorHandle& opHandle,
141
+ DispatchKeySet dispatchKeySet,
142
+ Stack* stack) const;
143
+
144
+ /**
145
+ * Create a KernelFunction from a boxed function.
146
+ *
147
+ * Example:
148
+ *
149
+ * > void boxed_func(OperatorKernel*, Stack* stack) {...}
150
+ * > BoxedFunction func = BoxedKernel::makeFromFunction<&boxed_func>();
151
+ */
152
+ template <BoxedKernelFunction* func>
153
+ static BoxedKernel makeFromFunction();
154
+
155
+ /**
156
+ * TODO: This will only be useful if we write a backend fallback that plumbs
157
+ * dispatch keys (currently there are none) See Note [Plumbing Keys Through
158
+ * The Dispatcher] for details.
159
+ */
160
+ template <BoxedKernelFunction_withDispatchKeys* func>
161
+ static BoxedKernel makeFromFunction();
162
+
163
+ /**
164
+ * Create a KernelFunction from a boxed functor.
165
+ *
166
+ * Example:
167
+ *
168
+ * > class MyFunctor final : public c10::OperatorKernel {
169
+ * > public:
170
+ * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
171
+ * > };
172
+ * > BoxedKernel func =
173
+ * BoxedKernel::makeFromFunctor(std::make_unique<MyFunctor>());
174
+ */
175
+ template <class KernelFunctor>
176
+ static BoxedKernel makeFromFunctor(
177
+ std::unique_ptr<KernelFunctor> kernelFunctor);
178
+
179
+ static BoxedKernel makeFallthrough();
180
+ static BoxedKernel makeAmbiguousAutogradOther();
181
+ static BoxedKernel makeNamedNotSupported();
182
+
183
+ private:
184
+ friend class KernelFunction;
185
+
186
+ template <BoxedKernelFunction* func>
187
+ static void make_boxed_function(
188
+ OperatorKernel*,
189
+ const OperatorHandle& opHandle,
190
+ DispatchKeySet,
191
+ Stack* stack);
192
+
193
+ template <BoxedKernelFunction_withDispatchKeys* func>
194
+ static void make_boxed_function(
195
+ OperatorKernel*,
196
+ const OperatorHandle& opHandle,
197
+ DispatchKeySet,
198
+ Stack* stack);
199
+
200
+ explicit BoxedKernel(
201
+ std::unique_ptr<OperatorKernel> functor,
202
+ InternalBoxedKernelFunction* boxed_kernel_func);
203
+
204
+ OperatorKernel* getFunctor() const;
205
+ InternalBoxedKernelFunction* getFnPtr() const;
206
+
207
+ c10::intrusive_ptr<OperatorKernel> functor_;
208
+ InternalBoxedKernelFunction* boxed_kernel_func_;
209
+ };
210
+
211
+ } // namespace c10
212
+
213
+ #include <ATen/core/boxing/BoxedKernel_impl.h>
.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace c10 {
4
+
5
+ inline BoxedKernel::BoxedKernel() : functor_(), boxed_kernel_func_(nullptr) {}
6
+
7
+ inline BoxedKernel::BoxedKernel(
8
+ std::unique_ptr<OperatorKernel> functor,
9
+ InternalBoxedKernelFunction* boxed_kernel_func)
10
+ : functor_(std::move(functor)), boxed_kernel_func_(boxed_kernel_func) {}
11
+
12
+ template <BoxedKernel::BoxedKernelFunction* func>
13
+ inline void BoxedKernel::make_boxed_function(
14
+ OperatorKernel*,
15
+ const OperatorHandle& opHandle,
16
+ DispatchKeySet,
17
+ Stack* stack) {
18
+ // Note that we're dropping the DispatchKeySet argument.
19
+ // See Note [Plumbing Keys Through The Dispatcher 2] for details.
20
+ func(opHandle, stack);
21
+ }
22
+
23
+ template <BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
24
+ inline void BoxedKernel::make_boxed_function(
25
+ OperatorKernel*,
26
+ const OperatorHandle& opHandle,
27
+ DispatchKeySet ks,
28
+ Stack* stack) {
29
+ // See Note [Plumbing Keys Through The Dispatcher 2] for details.
30
+ func(opHandle, ks, stack);
31
+ }
32
+
33
+ inline bool BoxedKernel::isValid() const {
34
+ return boxed_kernel_func_ != nullptr;
35
+ }
36
+
37
+ inline bool BoxedKernel::isFallthrough() const {
38
+ return boxed_kernel_func_ == &fallthrough_kernel;
39
+ }
40
+
41
+ inline void BoxedKernel::callBoxed(
42
+ const OperatorHandle& opHandle,
43
+ DispatchKeySet dispatchKeySet,
44
+ Stack* stack) const {
45
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
46
+ boxed_kernel_func_ != nullptr,
47
+ "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel.");
48
+ (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack);
49
+ }
50
+
51
+ template <BoxedKernel::BoxedKernelFunction* func>
52
+ inline BoxedKernel BoxedKernel::makeFromFunction() {
53
+ return BoxedKernel(
54
+ nullptr, // no functor_ object
55
+ &make_boxed_function<func>);
56
+ }
57
+
58
+ template <BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
59
+ inline BoxedKernel BoxedKernel::makeFromFunction() {
60
+ return BoxedKernel(
61
+ nullptr, // no functor_ object
62
+ &make_boxed_function<func>);
63
+ }
64
+
65
+ inline BoxedKernel BoxedKernel::makeFallthrough() {
66
+ return BoxedKernel(
67
+ nullptr, // no functor_ object
68
+ &fallthrough_kernel);
69
+ }
70
+
71
+ inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() {
72
+ return BoxedKernel(
73
+ nullptr, // no functor_ object
74
+ &ambiguous_autogradother_kernel);
75
+ }
76
+
77
+ inline BoxedKernel BoxedKernel::makeNamedNotSupported() {
78
+ return BoxedKernel(
79
+ nullptr, // no functor_ object
80
+ &named_not_supported_kernel);
81
+ }
82
+
83
+ template <class KernelFunctor>
84
+ inline BoxedKernel BoxedKernel::makeFromFunctor(
85
+ std::unique_ptr<KernelFunctor> kernelFunctor) {
86
+ static_assert(
87
+ std::is_base_of_v<OperatorKernel, KernelFunctor>,
88
+ "Tried to call BoxedKernel::makeFromFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
89
+ return BoxedKernel(
90
+ std::move(kernelFunctor),
91
+ [](OperatorKernel* kernel,
92
+ const OperatorHandle& op,
93
+ DispatchKeySet ks,
94
+ Stack* stack) {
95
+ (*static_cast<KernelFunctor*>(kernel))(op, ks, stack);
96
+ });
97
+ }
98
+
99
+ inline OperatorKernel* BoxedKernel::getFunctor() const {
100
+ return functor_.get();
101
+ }
102
+ inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const {
103
+ return boxed_kernel_func_;
104
+ }
105
+
106
+ } // namespace c10