BryanW commited on
Commit
8f253d2
·
verified ·
1 Parent(s): 175af23

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenGeneral.h +8 -0
  2. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenOpList.h +18 -0
  3. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_fwd.h +51 -0
  4. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_pch.h +166 -0
  5. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Array.h +53 -0
  6. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Backtrace.h +7 -0
  7. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/CachingHostAllocator.h +800 -0
  8. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/CheckMemoryFormat.h +29 -0
  9. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h +144 -0
  10. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h +38 -0
  11. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict.h +401 -0
  12. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict_inl.h +213 -0
  13. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DimVector.h +18 -0
  14. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dimname.h +53 -0
  15. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DistributionsHelper.h +337 -0
  16. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Formatting.h +30 -0
  17. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Generator.h +194 -0
  18. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h +44 -0
  19. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef.h +638 -0
  20. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef_inl.h +208 -0
  21. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h +116 -0
  22. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/List.h +496 -0
  23. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/List_inl.h +358 -0
  24. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/MT19937RNGEngine.h +199 -0
  25. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/NamedTensor.h +148 -0
  26. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h +192 -0
  27. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h +245 -0
  28. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonFallbackKernel.h +40 -0
  29. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h +27 -0
  30. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/QuantizerBase.h +89 -0
  31. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Range.h +30 -0
  32. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Reduction.h +19 -0
  33. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Scalar.h +6 -0
  34. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ScalarType.h +6 -0
  35. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Tensor.h +103 -0
  36. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorAccessor.h +66 -0
  37. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBase.h +1098 -0
  38. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBody.h +0 -0
  39. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TorchDispatchUtils.h +22 -0
  40. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TransformationHelper.h +180 -0
  41. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h +6 -0
  42. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/UnsafeFromTH.h +26 -0
  43. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/VariableHooksInterface.h +90 -0
  44. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Variadic.h +97 -0
  45. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Vitals.h +99 -0
  46. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/alias_info.h +167 -0
  47. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/aten_interned_strings.h +2309 -0
  48. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/blob.h +209 -0
  49. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/builtin_function.h +95 -0
  50. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/class_type.h +446 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenGeneral.h ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/macros/Macros.h>
5
+
6
+ #else
7
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
8
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenOpList.h ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/macros/Export.h>
5
+
6
+ namespace c10 {
7
+ struct OperatorName;
8
+ }
9
+
10
+ namespace at {
11
+
12
+ // check if an op is a custom op (i.e. did not come from native_functions.yaml)
13
+ TORCH_API bool is_custom_op(const c10::OperatorName& opName);
14
+ }
15
+
16
+ #else
17
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
18
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_fwd.h ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <c10/core/QScheme.h>
4
+
5
+ // Forward declarations of core ATen types used in dispatch functions
6
+ namespace c10 {
7
+
8
+ template<typename T>
9
+ class List;
10
+ template<typename T>
11
+ class IListRef;
12
+ class Stream;
13
+ class Scalar;
14
+ class SymInt;
15
+ class SymIntList;
16
+ struct Storage;
17
+ struct TensorOptions;
18
+ template <typename T>
19
+ class ArrayRef;
20
+ template <typename T>
21
+ class OptionalArrayRef;
22
+
23
+ } // namespace c10
24
+
25
+ namespace at {
26
+
27
+ class Tensor;
28
+ class OptionalTensorRef;
29
+ struct Dimname;
30
+ struct Generator;
31
+ using TensorList = c10::ArrayRef<Tensor>;
32
+ using ITensorListRef = c10::IListRef<Tensor>;
33
+ using IOptTensorListRef = c10::IListRef<OptionalTensorRef>;
34
+ using DimnameList = c10::ArrayRef<Dimname>;
35
+ using IntArrayRef = c10::ArrayRef<int64_t>;
36
+ using OptionalIntArrayRef = c10::OptionalArrayRef<int64_t>;
37
+ using OptionalSymIntArrayRef = c10::OptionalArrayRef<c10::SymInt>;
38
+
39
+ using c10::Stream;
40
+ using c10::Storage;
41
+ using c10::QScheme;
42
+ using c10::Scalar;
43
+ using c10::SymInt;
44
+ using c10::SymIntList;
45
+ using c10::TensorOptions;
46
+
47
+ } // namespace at
48
+
49
+ #else
50
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
51
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_pch.h ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // This global header must not depend on native_functions.yaml or
3
+ // incremental builds will be next to useless
4
+ #pragma push_macro("TORCH_ASSERT_NO_OPERATORS")
5
+ #define TORCH_ASSERT_NO_OPERATORS
6
+
7
+ #include <cinttypes>
8
+
9
+ // This list of headers was generated using a script that finds
10
+ // high-impact headers and then manually tweaked to remove OS specific
11
+ // or duplicate headers (e.g. <cassert> and <assert.h>) and to remove
12
+ // "impl" headers (e.g BFloat16-inl.h or complex_math.h in c10).
13
+
14
+ // To generate the initial list:
15
+ // 1. Build pytorch from scratch with all build caching disabled
16
+ // 2. Generate a build trace with ninjatracing (https://github.com/nico/ninjatracing)
17
+ // $ ninjatracing /path/to/pytorch/build/.ninja_log > trace_all.json
18
+ // 3. Run pch_gen.py from https://github.com/peterbell10/build_analysis/
19
+ // $ python pch_gen.py --threshold .80 --target torch_cpu --build_dir /path/to/pytorch/build --trace trace_all.json
20
+ // Where the threshold can be tweaked until c10 and some of ATen
21
+ // core are included but TORCH_ASSERT_NO_OPERATORS still passes.
22
+
23
+ #include <cerrno>
24
+ #include <cmath>
25
+ #include <cstddef>
26
+ #include <cstdint>
27
+ #include <cstdlib>
28
+ #include <cstring>
29
+
30
+ #include <algorithm>
31
+ #include <array>
32
+ #include <atomic>
33
+ #include <chrono>
34
+ #include <complex>
35
+ #include <deque>
36
+ #include <exception>
37
+ #include <functional>
38
+ #include <initializer_list>
39
+ #include <iomanip>
40
+ #include <iosfwd>
41
+ #include <iterator>
42
+ #include <limits>
43
+ #include <list>
44
+ #include <map>
45
+ #include <memory>
46
+ #include <mutex>
47
+ #include <new>
48
+ #include <numeric>
49
+ #include <ostream>
50
+ #include <sstream>
51
+ #include <stdexcept>
52
+ #include <string>
53
+ #include <string_view>
54
+ #include <tuple>
55
+ #include <type_traits>
56
+ #include <typeindex>
57
+ #include <typeinfo>
58
+ #include <unordered_map>
59
+ #include <unordered_set>
60
+ #include <utility>
61
+ #include <vector>
62
+
63
+ #include <c10/core/Allocator.h>
64
+ #include <c10/core/AutogradState.h>
65
+ #include <c10/core/Backend.h>
66
+ #include <c10/core/DefaultDtype.h>
67
+ #include <c10/core/Device.h>
68
+ #include <c10/core/DeviceType.h>
69
+ #include <c10/core/DispatchKey.h>
70
+ #include <c10/core/DispatchKeySet.h>
71
+ #include <c10/core/GeneratorImpl.h>
72
+ #include <c10/core/InferenceMode.h>
73
+ #include <c10/core/Layout.h>
74
+ #include <c10/core/MemoryFormat.h>
75
+ #include <c10/core/OptionalRef.h>
76
+ #include <c10/core/QScheme.h>
77
+ #include <c10/core/Scalar.h>
78
+ #include <c10/core/ScalarType.h>
79
+ #include <c10/core/ScalarTypeToTypeMeta.h>
80
+ #include <c10/core/Storage.h>
81
+ #include <c10/core/StorageImpl.h>
82
+ #include <c10/core/SymBool.h>
83
+ #include <c10/core/SymFloat.h>
84
+ #include <c10/core/SymInt.h>
85
+ #include <c10/core/SymIntArrayRef.h>
86
+ #include <c10/core/SymNodeImpl.h>
87
+ #include <c10/core/TensorImpl.h>
88
+ #include <c10/core/TensorOptions.h>
89
+ #include <c10/core/UndefinedTensorImpl.h>
90
+ #include <c10/core/WrapDimMinimal.h>
91
+ #include <c10/core/impl/LocalDispatchKeySet.h>
92
+ #include <c10/core/impl/PyInterpreter.h>
93
+ #include <c10/core/impl/SizesAndStrides.h>
94
+
95
+ #include <c10/macros/Export.h>
96
+ #include <c10/macros/Macros.h>
97
+
98
+ #include <c10/util/AlignOf.h>
99
+ #include <c10/util/ArrayRef.h>
100
+ #include <c10/util/BFloat16.h>
101
+ #include <c10/util/C++17.h>
102
+ #include <c10/util/ConstexprCrc.h>
103
+ #include <c10/util/Deprecated.h>
104
+ #include <c10/util/DimVector.h>
105
+ #include <c10/util/Exception.h>
106
+ #include <c10/util/ExclusivelyOwned.h>
107
+ #include <c10/util/Flags.h>
108
+ #include <c10/util/Float8_e4m3fn.h>
109
+ #include <c10/util/Float8_e5m2.h>
110
+ #include <c10/util/Float8_e4m3fnuz.h>
111
+ #include <c10/util/Float8_e5m2fnuz.h>
112
+ #include <c10/util/FunctionRef.h>
113
+ #include <c10/util/Half.h>
114
+ #include <c10/util/IdWrapper.h>
115
+ #include <c10/util/Logging.h>
116
+ #include <c10/util/MaybeOwned.h>
117
+ #include <c10/util/Metaprogramming.h>
118
+ #include <c10/util/Optional.h>
119
+ #include <c10/util/Registry.h>
120
+ #include <c10/util/SmallVector.h>
121
+ #include <c10/util/StringUtil.h>
122
+ #include <c10/util/ThreadLocalDebugInfo.h>
123
+ #include <c10/util/Type.h>
124
+ #include <c10/util/TypeCast.h>
125
+ #include <c10/util/TypeIndex.h>
126
+ #include <c10/util/TypeList.h>
127
+ #include <c10/util/TypeSafeSignMath.h>
128
+ #include <c10/util/TypeTraits.h>
129
+ #include <c10/util/UniqueVoidPtr.h>
130
+ #include <c10/util/accumulate.h>
131
+ #include <c10/util/bit_cast.h>
132
+ #include <c10/util/bits.h>
133
+ #include <c10/util/complex.h>
134
+ #include <c10/util/floating_point_utils.h>
135
+ #include <c10/util/intrusive_ptr.h>
136
+ #include <c10/util/irange.h>
137
+ #include <c10/util/llvmMathExtras.h>
138
+ #include <c10/util/python_stub.h>
139
+ #include <c10/util/qint32.h>
140
+ #include <c10/util/qint8.h>
141
+ #include <c10/util/quint2x4.h>
142
+ #include <c10/util/quint4x2.h>
143
+ #include <c10/util/quint8.h>
144
+ #include <c10/util/safe_numerics.h>
145
+ #include <c10/util/string_utils.h>
146
+ #include <c10/util/string_view.h>
147
+ #include <c10/util/typeid.h>
148
+
149
+ #include <ATen/StorageUtils.h>
150
+ #include <ATen/core/ATen_fwd.h>
151
+ #include <ATen/core/DeprecatedTypeProperties.h>
152
+ #include <ATen/core/DeprecatedTypePropertiesRegistry.h>
153
+ #include <ATen/core/DimVector.h>
154
+ #include <ATen/core/Dimname.h>
155
+ #include <ATen/core/Generator.h>
156
+ #include <ATen/core/NamedTensor.h>
157
+ #include <ATen/core/QuantizerBase.h>
158
+ #include <ATen/core/TensorAccessor.h>
159
+ #include <ATen/core/TensorBase.h>
160
+ #include <ATen/core/symbol.h>
161
+
162
+ #pragma pop_macro("TORCH_ASSERT_NO_OPERATORS")
163
+
164
+ #else
165
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
166
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Array.h ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // A fixed-size array type usable from both host and
5
+ // device code.
6
+
7
+ #include <c10/macros/Macros.h>
8
+ #include <c10/util/irange.h>
9
+
10
+ namespace at::detail {
11
+
12
+ template <typename T, int size_>
13
+ struct Array {
14
+ // NOLINTNEXTLINE(*c-array*)
15
+ T data[size_];
16
+
17
+ C10_HOST_DEVICE T operator[](int i) const {
18
+ return data[i];
19
+ }
20
+ C10_HOST_DEVICE T& operator[](int i) {
21
+ return data[i];
22
+ }
23
+ #if defined(USE_ROCM)
24
+ C10_HOST_DEVICE Array() = default;
25
+ C10_HOST_DEVICE Array(const Array&) = default;
26
+ C10_HOST_DEVICE Array& operator=(const Array&) = default;
27
+ C10_HOST_DEVICE Array(Array&&) = default;
28
+ C10_HOST_DEVICE Array& operator=(Array&&) = default;
29
+ C10_HOST_DEVICE ~Array() = default;
30
+ #else
31
+ Array() = default;
32
+ Array(const Array&) = default;
33
+ Array& operator=(const Array&) = default;
34
+ Array(Array&&) noexcept = default;
35
+ Array& operator=(Array&&) noexcept = default;
36
+ ~Array() = default;
37
+ #endif
38
+ static constexpr int size() {
39
+ return size_;
40
+ }
41
+ // Fill the array with x.
42
+ C10_HOST_DEVICE Array(T x) {
43
+ for (int i = 0; i < size_; i++) {
44
+ data[i] = x;
45
+ }
46
+ }
47
+ };
48
+
49
+ } // namespace at::detail
50
+
51
+ #else
52
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
53
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Backtrace.h ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #include <c10/util/Backtrace.h>
3
+ #include <c10/util/Type.h>
4
+
5
+ #else
6
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
7
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/CachingHostAllocator.h ADDED
@@ -0,0 +1,800 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/Allocator.h>
5
+ #include <c10/core/AllocatorConfig.h>
6
+ #include <c10/core/Stream.h>
7
+ #include <c10/core/thread_pool.h>
8
+ #include <c10/util/flat_hash_map.h>
9
+ #include <c10/util/llvmMathExtras.h>
10
+ #include <iostream>
11
+ #include <optional>
12
+
13
+ #include <deque>
14
+ #include <mutex>
15
+
16
+ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
17
+ namespace at {
18
+
19
+ using c10::CachingAllocator::Stat;
20
+ using c10::CachingAllocator::DurationStat;
21
+
22
+ /**
23
+ * HostBlock is typically a fundamental memory block used in pinned memory. It
24
+ * is likely related to Event and Stream of device runtime. It is probably a
25
+ * base struct or interface that can be inherited and extended by each backend.
26
+ */
27
+ template <typename S>
28
+ struct HostBlock {
29
+ // constructor for search key
30
+ HostBlock(size_t size) : size_(size) {}
31
+
32
+ HostBlock(size_t size, void* ptr) : size_(size), ptr_(ptr) {}
33
+
34
+ std::mutex mutex_;
35
+ size_t size_{0}; // block size in bytes
36
+ void* ptr_{nullptr}; // memory address
37
+ bool allocated_{false}; // in-use flag
38
+ size_t event_count_{0}; // number of related events
39
+ ska::flat_hash_set<S> streams_; // streams on which the block was used
40
+ };
41
+
42
+ template <typename B>
43
+ struct alignas(hardware_destructive_interference_size) FreeBlockList {
44
+ std::mutex mutex_;
45
+ std::deque<B*> list_;
46
+ };
47
+
48
+ namespace {
49
+ // Max cached block sizes: (1 << MAX_SIZE_INDEX) bytes
50
+ // NOLINTNEXTLINE(misc-definitions-in-headers)
51
+ constexpr size_t MAX_SIZE_INDEX = 64;
52
+ }
53
+
54
+ // A large reserved pinned memory segment that is created in advance which is used
55
+ // to allocate small pinned memory requests to avoid calling into expensive APIs.
56
+ // We never free this memory and move up the pointer as we allocate new blocks
57
+ // and when blocks are freed, they are cached in the free lists.
58
+ struct PinnedReserveSegment {
59
+ PinnedReserveSegment(void *start, size_t size) : start_(start), size_(size),
60
+ current_ptr_(start_), initialized_(true) {}
61
+
62
+ PinnedReserveSegment() : start_(nullptr), size_(0), current_ptr_(nullptr), initialized_(false) {}
63
+
64
+ bool initialized() {
65
+ return initialized_;
66
+ }
67
+
68
+ void* allocate(size_t bytes) {
69
+ std::lock_guard<std::mutex> guard(mutex_);
70
+
71
+ // Round up the requested size to 4KB boundary for all including the small ones.
72
+ size_t rounded_bytes = (bytes + 4096 - 1) & ~(4096 - 1);
73
+
74
+ if (((uint8_t*)current_ptr_ + rounded_bytes) > ((uint8_t*)start_ + size_)) {
75
+ return nullptr;
76
+ }
77
+
78
+ void* ptr = current_ptr_;
79
+ current_ptr_ = (uint8_t*)current_ptr_ + rounded_bytes;
80
+ return ptr;
81
+ }
82
+
83
+ bool owns(void* ptr) {
84
+ return ptr >= start_ && ptr < (uint8_t*)start_ + size_;
85
+ }
86
+
87
+ std::mutex mutex_;
88
+ void* start_;
89
+ size_t size_;
90
+ void* current_ptr_;
91
+ bool initialized_;
92
+ };
93
+
94
+ // Struct containing memory allocator summary statistics for host.
95
+ struct TORCH_API HostStats {
96
+ // COUNT: total allocations (active)
97
+ Stat active_requests;
98
+ // SUM: bytes allocated/reserved by this memory allocator. (active)
99
+ Stat active_bytes;
100
+ // COUNT: total allocations (active + free)
101
+ Stat allocations;
102
+ // SUM: bytes allocated/reserved by this memory allocator. This accounts
103
+ // for both free and in-use blocks.
104
+ Stat allocated_bytes;
105
+
106
+ // SUM: time spent in cudaHostAlloc/cudaHostRegister in microseconds
107
+ DurationStat host_alloc_time;
108
+
109
+ // SUM: time spent in cudaHostFree/cudaHostUnregister in microseconds
110
+ DurationStat host_free_time;
111
+
112
+ // COUNT: number of times cudaHostAlloc/cudaHostRegister was called because
113
+ // the request could not be satisfied from existing free blocks.
114
+ int64_t num_host_alloc = 0; // This is derived from segment or timing
115
+
116
+ // COUNT: number of times cudaHostFree/cudaHostUnregister was called.
117
+ int64_t num_host_free = 0; // This is derived from segment or timing
118
+
119
+ // Count of cudaHostAlloc/cudaHostRegister per bucket
120
+ std::vector<int64_t> bucket_allocation = std::vector<int64_t>(MAX_SIZE_INDEX);
121
+ };
122
+
123
+ // Struct containing memory allocator summary statistics for host, as they
124
+ // are staged for reporting. This is a temporary struct that is used to
125
+ // avoid locking the allocator while collecting stats.
126
+ struct alignas(hardware_destructive_interference_size) HostStatsStaged {
127
+ std::mutex timing_mutex_;
128
+ // COUNT: total allocations (active + free)
129
+ // LOCK: access to this stat is protected by the allocator's blocks_mutex_
130
+ Stat allocations;
131
+ // SUM: bytes allocated/reserved by this memory allocator. This accounts
132
+ // for both free and in-use blocks.
133
+ Stat allocated_bytes;
134
+ // COUNT: number of allocations per bucket (active)
135
+ // LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_
136
+ std::vector<Stat> active_bucket_stats = std::vector<Stat>(MAX_SIZE_INDEX);
137
+ // SUM: bytes of allocation per bucket (active)
138
+ // LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_
139
+ std::vector<Stat> active_bytes_bucket_stats = std::vector<Stat>(MAX_SIZE_INDEX);
140
+ // COUNT: number of allocations per bucket (active + free)
141
+ // LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_
142
+ std::vector<Stat> allocation_bucket_stats = std::vector<Stat>(MAX_SIZE_INDEX);
143
+ // SUM: bytes of allocation per bucket (active + free)
144
+ // LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_
145
+ std::vector<Stat> allocated_bytes_bucket_stats = std::vector<Stat>(MAX_SIZE_INDEX);
146
+ // SUM: time spent in cudaHostAlloc/cudaHostRegister
147
+ // LOCK: access to this stat is protected by the timing_mutex_
148
+ DurationStat host_alloc_time;
149
+ // SUM: time spent in cudaHostFree/cudaHostUnregister
150
+ // LOCK: access to this stat is protected by the timing_mutex_
151
+ DurationStat host_free_time;
152
+ };
153
+
154
+ /**
155
+ * Note [HostAllocator design]
156
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
157
+ * We have three key data structures - the free list which stores blocks that
158
+ * are not currently used, the block list which stores all blocks that have been
159
+ * allocated, and the event queue which stores runtime events and their
160
+ * corresponding blocks.
161
+ *
162
+ * Each of these are protected by a separate mutex. The key design principles
163
+ * are to 1) only hold each mutex for the minimal amount of time possible, 2)
164
+ * never do any possible expensive operations (such as CUDA runtime API calls)
165
+ * while holding the lock.
166
+ *
167
+ * There are four public methods: allocate, free, record_event and empty_cache.
168
+ * 1) In the allocate path, we first check to see if we can service our
169
+ * request from this free list, and otherwise we create a new block with
170
+ * allocate_host_memory.
171
+ * 2) In the free path, we insert events (if required) into the event queue,
172
+ * and if possible insert our block back into the free list. In allocate, we
173
+ * first eagerly query events until we find one that is not ready, and insert
174
+ * the corresponding block onto the free list if all the events recorded for a
175
+ * block are ready.
176
+ * 3) In the record_event path, we simply insert the given stream into the set
177
+ * of streams tracked by the specified block. This set of streams is then
178
+ * consumed in the free path.
179
+ * 4) In the empty_cache path, we flush any available blocks into the free
180
+ * list. Remove all element of free list, then remove them from block list and
181
+ * release the associated pinned memory allocation via free_block.
182
+ *
183
+ * We generalize the caching host allocator into two parts: interface and
184
+ * implementation. For any new backend looking to integrate with host allocator
185
+ * and reuse caching mechanism, these two parts are necessary to be specialized.
186
+ *
187
+ * For the implementation, we provide a CachingHostAllocatorImpl struct
188
+ * to abstract the caching mechanism. Any backend needs to provide a customized
189
+ * implementation by specializing its own public functions and the related
190
+ * runtime functions. Its template parameter S represents runtime Stream, E
191
+ * denotes runtime Event, B indicates the fundamental memory block.
192
+ *
193
+ * For the interface, we provide a CachingHostAllocatorInterface struct as an
194
+ * interface. Any backend needs to derive its own host allocator from this
195
+ * interface. Its template parameter T refers to an implementation that
196
+ * inherited from CachingHostAllocatorImpl.
197
+ *
198
+ * So this design can share the caching mechanism across each backend, and
199
+ * provide flexibility to each backend. A backend can choose to follow this
200
+ * implementation or reuse them by extending and overriding them as necessary.
201
+ * Taking CUDA as an example, it specializes runtime related functions to reuse
202
+ * the caching mechanism. Additionally, it extends the allocator's functionality
203
+ * by adding the allocWithCudaHostRegister function to support page-locking the
204
+ * memory range used by CUDA. Of course, you can also refer to
205
+ * XPUCachingHostAllocator, which is a host caching allocator supported on XPU
206
+ * backend, to implement a basic host caching allocator.
207
+ *
208
+ * Some of the invariants here are less strict than they could be - for example,
209
+ * we do not enforce that free(Block* block) => block->event_count == 0. This is
210
+ * for compatibility reasons, and we can explore enforcing these in subsequent
211
+ * versions.
212
+ *
213
+ * Note that this caching host allocator does not split larger allocations into
214
+ * smaller blocks, unlike the caching device allocator.
215
+ *
216
+ * In order to gather statistics about caching host allocator while minimally
217
+ * impacting performance, we use a HostStatsStaged struct to stage the stats
218
+ * before reporting them. This is done to avoid adding new locks to the allocator.
219
+ * Collecting stats is carefully done under existing locks, and then the staged
220
+ * stats are converted to the final stats when getStats is called. At that time
221
+ * we hold the same locks as empty_cache, to ensure the fidelity of the stats.
222
+ */
223
+
224
+ template <
225
+ typename S,
226
+ typename E,
227
+ typename B = HostBlock<S>>
228
+ struct CachingHostAllocatorImpl {
229
+ virtual ~CachingHostAllocatorImpl() {
230
+ if (active_) {
231
+ active_ = false;
232
+ getBackgroundThreadPool()->waitWorkComplete();
233
+ }
234
+ }
235
+
236
+ public:
237
+ // return data_ptr and block pair.
238
+ virtual std::pair<void*, void*> allocate(size_t size) {
239
+ if (size == 0) {
240
+ return {nullptr, nullptr};
241
+ }
242
+
243
+ // If we are using background threads, we can process events in the
244
+ // background.
245
+ if (!pinned_use_background_threads()) {
246
+ process_events();
247
+ }
248
+
249
+ // Round up the allocation to the nearest power of two to improve reuse.
250
+ // These power of two sizes are also used to index into the free list.
251
+ size_t roundSize = c10::llvm::PowerOf2Ceil(size);
252
+
253
+ // First, try to allocate from the free list
254
+ auto* block = get_free_block(roundSize);
255
+ if (block) {
256
+ return {block->ptr_, reinterpret_cast<void*>(block)};
257
+ }
258
+
259
+ // Check in the recently freed blocks with pending events to see if we
260
+ // can reuse them. Call get_free_block again after processing events
261
+ if (pinned_use_background_threads()) {
262
+ // Launch the background thread and process events in a loop.
263
+ static bool background_thread_flag [[maybe_unused]] = [this] {
264
+ active_ = true;
265
+ getBackgroundThreadPool()->run([&]() {
266
+ while (active_) {
267
+ process_events();
268
+ std::this_thread::sleep_for(std::chrono::microseconds(100));
269
+ }
270
+ });
271
+ return true;
272
+ }();
273
+ }
274
+
275
+ // Slow path: if we can't allocate from the cached free list, we need
276
+ // to create a new block.
277
+ void* ptr = nullptr;
278
+ allocate_host_memory(roundSize, &ptr);
279
+
280
+ // Then, create a new block.
281
+ block = new B(roundSize, ptr);
282
+ block->allocated_ = true;
283
+
284
+ add_allocated_block(block);
285
+ return {block->ptr_, reinterpret_cast<void*>(block)};
286
+ }
287
+
288
+ virtual void free(void* ctx) {
289
+ if (!ctx) {
290
+ return;
291
+ }
292
+
293
+ // Note: we can assume that free is correctly paired with alloc, and thus we
294
+ // do not need to look up the ctx in blocks_.
295
+ auto* block = reinterpret_cast<B*>(ctx);
296
+
297
+ std::optional<std::vector<E>> events;
298
+ ska::flat_hash_set<S> streams;
299
+ {
300
+ std::lock_guard<std::mutex> g(block->mutex_);
301
+ block->allocated_ = false;
302
+ if (block->streams_.empty()) {
303
+ TORCH_INTERNAL_ASSERT(block->event_count_ == 0);
304
+ } else {
305
+ events = std::vector<E>();
306
+ events->reserve(block->streams_.size());
307
+ block->event_count_ += block->streams_.size();
308
+ // Move out streams to avoid holding the mutex during event recording
309
+ streams = std::move(block->streams_);
310
+ block->streams_.clear();
311
+ }
312
+ }
313
+
314
+ // Event recording must be done outside the mutex to avoid potential
315
+ // deadlocks (e.g., when Python GIL is involved)
316
+ for (auto stream : streams) {
317
+ record_stream(events, stream);
318
+ }
319
+
320
+ if (!events) {
321
+ auto index = size_index(block->size_);
322
+ std::lock_guard<std::mutex> g(free_list_[index].mutex_);
323
+ free_list_[index].list_.push_back(block);
324
+ } else {
325
+ // restore these events that record by used streams.
326
+ std::lock_guard<std::mutex> g(events_mutex_);
327
+ for (auto&& event : *events) {
328
+ events_.emplace_front(std::move(event), block);
329
+ }
330
+ }
331
+ }
332
+
333
+ virtual bool record_event(void* ptr, void* ctx, c10::Stream s) {
334
+ S stream = S(s);
335
+ auto* block = reinterpret_cast<B*>(ctx);
336
+
337
+ // Note: we need to check if the passed-in `ctx` is valid. This is because
338
+ // `record_event` (via `CachingHostAllocator_recordEvent`) can be invoked on
339
+ // an arbitrary tensor, and is not guaranteed to correspond to a pinned
340
+ // memory allocation. Therefore, we need to check that `ctx` is valid before
341
+ // proceeding.
342
+ {
343
+ std::lock_guard<std::mutex> g(blocks_mutex_);
344
+ if (blocks_.find(block) != blocks_.end()) {
345
+ // Now we know this object is safe to access.
346
+ std::lock_guard<std::mutex> gb(block->mutex_);
347
+ TORCH_INTERNAL_ASSERT(block->allocated_);
348
+ block->streams_.insert(stream);
349
+ return true;
350
+ }
351
+ auto it = ptr_to_block_.find(ptr);
352
+ if (it != ptr_to_block_.end()) {
353
+ block = it->second;
354
+ std::lock_guard<std::mutex> g(block->mutex_);
355
+ TORCH_INTERNAL_ASSERT(block->allocated_);
356
+ block->streams_.insert(stream);
357
+ return true;
358
+ }
359
+ }
360
+
361
+ return false;
362
+ }
363
+
364
+ virtual void empty_cache() {
365
+ // Flush any available blocks into the free_list.
366
+ process_events();
367
+
368
+ // Remove all elements from the free list, remove them from the blocks
369
+ // list, and free the associated pinned memory allocation. This requires
370
+ // concurrently holding both the free list mutexes and the blocks mutex, and
371
+ // is the only function that concurrently holds multiple mutexes.
372
+ for (size_t i = 0; i < free_list_.size(); ++i) {
373
+ std::lock(free_list_[i].mutex_, blocks_mutex_);
374
+ std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
375
+ std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
376
+
377
+ std::vector<B*> blocks_to_remove(free_list_[i].list_.begin(), free_list_[i].list_.end());
378
+ free_list_[i].list_.clear();
379
+
380
+ for (auto* block : blocks_to_remove) {
381
+ blocks_.erase(block);
382
+ ptr_to_block_.erase(block->ptr_);
383
+ auto index = size_index(block->size_);
384
+ free_block(block);
385
+ stats_.allocations.decrease(1);
386
+ stats_.allocated_bytes.decrease(block->size_);
387
+ stats_.allocation_bucket_stats[index].decrease(1);
388
+ stats_.allocated_bytes_bucket_stats[index].decrease(block->size_);
389
+ delete block;
390
+ }
391
+ }
392
+ }
393
+
394
+ inline size_t size_index(size_t size) {
395
+ return c10::llvm::Log2_64_Ceil(size);
396
+ }
397
+
398
+ virtual bool pinned_use_background_threads() {
399
+ return c10::CachingAllocator::AcceleratorAllocatorConfig::
400
+ pinned_use_background_threads();
401
+ }
402
+
403
+ virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const {
404
+ TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data");
405
+ }
406
+
407
+ HostStats getStats() {
408
+ HostStats stats;
409
+
410
+ // To keep getStats lightweight we do *not* flush any available blocks
411
+ // into the free_list. This may skew the stats a bit.
412
+
413
+ auto add_bucket_stats = [](Stat& accumulator, const Stat& other) {
414
+ accumulator.allocated += other.allocated;
415
+ accumulator.current += other.current;
416
+ accumulator.freed += other.freed;
417
+ // Since peaks are measured per bucket independently, we add them up
418
+ // to estimate the total peak. This is not strictly correct, but it is
419
+ // the best approximation we can get after the fact.
420
+ accumulator.peak += other.peak;
421
+ };
422
+
423
+ // Accurate reading of memory stats requires concurrently holding both the
424
+ // free list mutexes and the blocks mutex. Previously, this was only done in
425
+ // empty_cache function.
426
+ for (size_t i = 0; i < free_list_.size(); ++i) {
427
+ std::lock(free_list_[i].mutex_, blocks_mutex_);
428
+ std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
429
+ std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
430
+
431
+ // We collect the slow-path stats only once, since they are not collected
432
+ // per bucket (we pick index 0 arbitrarily). These are also all the host
433
+ // allocations, not taking into account caching and free lists.
434
+ if (i == 0) {
435
+ stats.allocations = stats_.allocations;
436
+ stats.allocated_bytes = stats_.allocated_bytes;
437
+ stats.num_host_alloc = stats.allocations.allocated;
438
+ stats.num_host_free = stats.allocations.freed;
439
+ }
440
+
441
+ // Bucket stats need to be merged with the slow-path stats. We do this in
442
+ // a best effort manner, since we can't really replay the cached events per bucket.
443
+ add_bucket_stats(stats.active_requests, stats_.active_bucket_stats[i]);
444
+ add_bucket_stats(stats.active_bytes, stats_.active_bytes_bucket_stats[i]);
445
+ stats.bucket_allocation[i] = stats_.allocation_bucket_stats[i].allocated;
446
+ }
447
+
448
+ // Get the timing stats
449
+ {
450
+ std::lock_guard<std::mutex> g(stats_.timing_mutex_);
451
+
452
+ stats.host_alloc_time = stats_.host_alloc_time;
453
+ stats.host_free_time = stats_.host_free_time;
454
+ }
455
+
456
+ return stats;
457
+ }
458
+
459
+ void resetAccumulatedStats() {
460
+ // Resetting accumulated memory stats requires concurrently holding both the
461
+ // free list mutexes and the blocks mutex. Previously, this was only done in
462
+ // empty_cache function.
463
+ for (size_t i = 0; i < free_list_.size(); ++i) {
464
+ std::lock(free_list_[i].mutex_, blocks_mutex_);
465
+ std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
466
+ std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
467
+
468
+ if (i == 0) {
469
+ stats_.allocations.reset_accumulated();
470
+ stats_.allocated_bytes.reset_accumulated();
471
+ }
472
+ stats_.active_bucket_stats[i].reset_accumulated();
473
+ stats_.active_bytes_bucket_stats[i].reset_accumulated();
474
+ stats_.allocation_bucket_stats[i].reset_accumulated();
475
+ stats_.allocated_bytes_bucket_stats[i].reset_accumulated();
476
+ }
477
+
478
+ // Also reset timing stats
479
+ {
480
+ std::lock_guard<std::mutex> g(stats_.timing_mutex_);
481
+ stats_.host_alloc_time.reset_accumulated();
482
+ stats_.host_free_time.reset_accumulated();
483
+ }
484
+ }
485
+
486
+ void resetPeakStats() {
487
+ // Resetting peak memory stats requires concurrently holding both the
488
+ // free list mutexes and the blocks mutex. Previously, this was only done in
489
+ // empty_cache function.
490
+ for (size_t i = 0; i < free_list_.size(); ++i) {
491
+ std::lock(free_list_[i].mutex_, blocks_mutex_);
492
+ std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
493
+ std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
494
+
495
+ if (i == 0) {
496
+ stats_.allocations.reset_peak();
497
+ stats_.allocated_bytes.reset_peak();
498
+ }
499
+ stats_.active_bucket_stats[i].reset_peak();
500
+ stats_.active_bytes_bucket_stats[i].reset_peak();
501
+ stats_.allocation_bucket_stats[i].reset_peak();
502
+ stats_.allocated_bytes_bucket_stats[i].reset_peak();
503
+ }
504
+
505
+ // Also reset timing stats
506
+ {
507
+ std::lock_guard<std::mutex> g(stats_.timing_mutex_);
508
+ stats_.host_alloc_time.reset_peak();
509
+ stats_.host_free_time.reset_peak();
510
+ }
511
+ }
512
+
513
+ private:
514
+ virtual void add_allocated_block(B* block) {
515
+ std::lock_guard<std::mutex> g(blocks_mutex_);
516
+ blocks_.insert(block);
517
+ stats_.allocations.increase(1);
518
+ stats_.allocated_bytes.increase(block->size_);
519
+ ptr_to_block_.insert({block->ptr_, block});
520
+
521
+ // Unfortunately, we have to, on the slow path, quickly
522
+ // lock the bucket to record the allocation. This should
523
+ // be a rare event once the cache is warmed up.
524
+ auto size = block->size_;
525
+ auto index = size_index(size);
526
+ {
527
+ std::lock_guard<std::mutex> g(free_list_[index].mutex_);
528
+ stats_.allocation_bucket_stats[index].increase(1);
529
+ stats_.allocated_bytes_bucket_stats[index].increase(size);
530
+ stats_.active_bucket_stats[index].increase(1);
531
+ stats_.active_bytes_bucket_stats[index].increase(size);
532
+ }
533
+ }
534
+
535
+ virtual B* get_free_block(size_t size) {
536
+ auto index = size_index(size);
537
+ std::lock_guard<std::mutex> g(free_list_[index].mutex_);
538
+ if (!free_list_[index].list_.empty()) {
539
+ B* block = free_list_[index].list_.back();
540
+ free_list_[index].list_.pop_back();
541
+ block->allocated_ = true;
542
+ stats_.active_bucket_stats[index].increase(1);
543
+ stats_.active_bytes_bucket_stats[index].increase(size);
544
+ return block;
545
+ }
546
+ return nullptr;
547
+ }
548
+
549
+ virtual void process_events() {
550
+ // process all events until the last unready event, not for specific size.
551
+ process_events_for_specific_size(-1);
552
+ }
553
+
554
+ // If size is -1, process all events from backwards until the last unready
555
+ // event. Otherwise, process events for a specific size and on first ready block
556
+ // is found, add it to the free list and return.
557
+ virtual void process_events_for_specific_size(int64_t size) {
558
+ size_t event_count = 0;
559
+ size_t max_events = 0;
560
+ {
561
+ std::lock_guard<std::mutex> g(events_mutex_);
562
+ max_events = events_.size();
563
+ }
564
+
565
+ while (true) {
566
+ // Avoid calling cudaEventDestroy while holding a mutex, so move
567
+ // intermediate events out of the lock into this object.
568
+ // process the last event
569
+ std::optional<std::pair<E, B*>> processed;
570
+ {
571
+ std::lock_guard<std::mutex> g(events_mutex_);
572
+ if (!events_.empty()) {
573
+ processed = std::move(events_.back());
574
+ events_.pop_back();
575
+ }
576
+ }
577
+
578
+ if (!processed) {
579
+ return;
580
+ }
581
+
582
+ if (size != -1) {
583
+ if (event_count++ > max_events) {
584
+ {
585
+ std::lock_guard<std::mutex> g(events_mutex_);
586
+ events_.push_front(std::move(*processed));
587
+ }
588
+ return;
589
+ }
590
+ if (size != (int64_t)processed->second->size_) {
591
+ // if we are processing a specific size, and the size of the block
592
+ // doesn't match, we can't use it.
593
+ {
594
+ std::lock_guard<std::mutex> g(events_mutex_);
595
+ events_.push_front(std::move(*processed));
596
+ }
597
+ continue;
598
+ }
599
+ }
600
+
601
+ // otherwise, query the event
602
+ {
603
+ // now, see if we can handle this element
604
+ auto& event = processed->first;
605
+ if (!query_event(event)) {
606
+ // push the event onto the back if it's not ready.
607
+ {
608
+ std::lock_guard<std::mutex> g(events_mutex_);
609
+ if (size == -1) {
610
+ events_.push_back(std::move(*processed));
611
+ return;
612
+ } else {
613
+ events_.push_front(std::move(*processed));
614
+ continue;
615
+ }
616
+ }
617
+ }
618
+ }
619
+
620
+ // Process the events.
621
+ TORCH_INTERNAL_ASSERT(processed);
622
+ auto* block = processed->second;
623
+ bool available = false;
624
+ {
625
+ std::lock_guard<std::mutex> g(block->mutex_);
626
+ TORCH_INTERNAL_ASSERT(!block->allocated_)
627
+ block->event_count_--;
628
+ if (block->event_count_ == 0) {
629
+ available = true;
630
+ }
631
+ }
632
+
633
+ if (available) {
634
+ auto index = size_index(block->size_);
635
+ std::lock_guard<std::mutex> g(free_list_[index].mutex_);
636
+ free_list_[index].list_.push_back(block);
637
+ stats_.active_bucket_stats[index].decrease(1);
638
+ stats_.active_bytes_bucket_stats[index].decrease(size);
639
+ if (size != -1) {
640
+ return;
641
+ }
642
+ }
643
+ }
644
+ }
645
+
646
+ TaskThreadPool* getBackgroundThreadPool() {
647
+ static TaskThreadPool* pool = new TaskThreadPool(1);
648
+ return pool;
649
+ }
650
+
651
+ /* These following functions are runtime-related. */
652
+
653
+ // Allocate page-locked memory on the host.
654
+ virtual void allocate_host_memory(size_t size, void** ptr) {
655
+ TORCH_CHECK_NOT_IMPLEMENTED(
656
+ false, "Not implemented for allocate_host_memory");
657
+ }
658
+
659
+ // Free block and release the pointer contained in block.
660
+ virtual void free_block(B* block) {
661
+ TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block");
662
+ }
663
+
664
+ // Record an event on stream and store event into events.
665
+ virtual void record_stream(std::optional<std::vector<E>>& events, S stream) {
666
+ TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream");
667
+ }
668
+
669
+ // Query event if it is completed.
670
+ virtual bool query_event(E& event) {
671
+ TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
672
+ }
673
+
674
+ alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
675
+ ska::flat_hash_set<B*> blocks_; // block list
676
+ ska::flat_hash_map<void*, B*> ptr_to_block_;
677
+
678
+ // We keep free list as a vector of free lists, one for each power of two
679
+ // size. This allows us to quickly find a free block of the right size.
680
+ // We use deque to store per size free list and guard the list with its own
681
+ // mutex.
682
+ alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>>
683
+ free_list_{MAX_SIZE_INDEX};
684
+
685
+ alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
686
+ std::deque<std::pair<E, B*>> events_; // event queue paired with block
687
+
688
+ // Indicates whether the event-processing thread pool is active.
689
+ // Set to false in the destructor to signal background threads to stop.
690
+ std::atomic<bool> active_{false};
691
+ protected:
692
+ alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
693
+ };
694
+
695
+ struct TORCH_API HostAllocator : public at::Allocator {
696
+ // Associates the pinned memory allocation with a stream to track
697
+ // dependencies. This ensures the memory won't be reused until the stream's
698
+ // operations complete
699
+ virtual bool record_event(void* ptr, void* ctx, c10::Stream stream) = 0;
700
+
701
+ // Frees all cached pinned memory and returns it to the system, clearing the
702
+ // allocator's internal cache
703
+ virtual void empty_cache() = 0;
704
+
705
+ // Returns comprehensive statistics about the allocator's memory usage,
706
+ // allocation patterns, and timing metrics
707
+ virtual HostStats get_stats() = 0;
708
+
709
+ // Resets the cumulative allocation statistics
710
+ virtual void reset_accumulated_stats() = 0;
711
+
712
+ // Resets the peak memory usage metrics
713
+ virtual void reset_peak_stats() = 0;
714
+ };
715
+
716
+ template <typename T, c10::DeleterFnPtr deleteFunc>
717
+ struct CachingHostAllocatorInterface : public HostAllocator {
718
+ CachingHostAllocatorInterface() : impl_(std::make_unique<T>()) {}
719
+
720
+ at::DataPtr allocate(size_t size) override {
721
+ auto ptr_and_ctx = impl_->allocate(size);
722
+ return {
723
+ ptr_and_ctx.first,
724
+ ptr_and_ctx.second,
725
+ deleteFunc, // Use the template parameter deleter function
726
+ at::DeviceType::CPU};
727
+ }
728
+
729
+ void free(void* ctx) {
730
+ impl_->free(ctx);
731
+ }
732
+
733
+ bool record_event(void* ptr, void* ctx, c10::Stream stream) override {
734
+ return impl_->record_event(ptr, ctx, stream);
735
+ }
736
+
737
+ void empty_cache() override {
738
+ impl_->empty_cache();
739
+ }
740
+
741
+ void copy_data(void* dest, const void* src, std::size_t count)
742
+ const override {
743
+ impl_->copy_data(dest, src, count);
744
+ }
745
+
746
+ HostStats get_stats() override {
747
+ return impl_->getStats();
748
+ }
749
+
750
+ void reset_accumulated_stats() override {
751
+ impl_->resetAccumulatedStats();
752
+ }
753
+
754
+ void reset_peak_stats() override {
755
+ impl_->resetPeakStats();
756
+ }
757
+
758
+ std::unique_ptr<T> impl_;
759
+ };
760
+
761
+ #define DECLARE_HOST_ALLOCATOR(name, impl, deleter, instance) \
762
+ void deleter(void* ptr); \
763
+ struct name final \
764
+ : public at::CachingHostAllocatorInterface<impl, deleter> {}; \
765
+ static name instance; \
766
+ void deleter(void* ptr) { \
767
+ instance.free(ptr); \
768
+ }
769
+
770
+ /**
771
+ * Set the host allocator for DeviceType `device_type`. This allocator manages
772
+ * pinned memory on the host that can be accessed efficiently by the specified
773
+ * device type. Note that this function is not thread-safe.
774
+ */
775
+ TORCH_API void setHostAllocator(
776
+ at::DeviceType device_type,
777
+ at::HostAllocator* allocator,
778
+ uint8_t priority = 0);
779
+
780
+ TORCH_API at::HostAllocator* getHostAllocator(at::DeviceType device_type);
781
+
782
+ template <DeviceType device_type>
783
+ struct HostAllocatorRegistry {
784
+ explicit HostAllocatorRegistry(HostAllocator* allocator) {
785
+ at::setHostAllocator(device_type, allocator);
786
+ }
787
+ };
788
+
789
+ #define REGISTER_HOST_ALLOCATOR(device_type, allocator) \
790
+ namespace { \
791
+ static at::HostAllocatorRegistry<device_type> \
792
+ g_host_allocator_registry_instance(allocator); \
793
+ }
794
+
795
+ } // namespace at
796
+ C10_DIAGNOSTIC_POP()
797
+
798
+ #else
799
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
800
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/CheckMemoryFormat.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #include <c10/core/TensorOptions.h>
3
+
4
+ namespace c10::impl {
5
+
6
+ inline std::optional<MemoryFormat>
7
+ check_tensor_options_and_extract_memory_format(
8
+ const TensorOptions& options,
9
+ std::optional<MemoryFormat> memory_format) {
10
+ TORCH_CHECK(
11
+ options.requires_grad_opt() != true,
12
+ "Operators taking TensorOptions cannot take a TensorOptions with "
13
+ "options.requires_grad set as true. This isn't implemented yet.");
14
+ TORCH_CHECK(
15
+ !(options.has_memory_format() && memory_format.has_value()),
16
+ "Cannot set memory_format both in TensorOptions and explicit argument; please delete "
17
+ "the redundant setter.");
18
+ if (memory_format.has_value()) {
19
+ return memory_format;
20
+ } else {
21
+ return options.memory_format_opt();
22
+ }
23
+ }
24
+
25
+ } // namespace impl namespace c10
26
+
27
+ #else
28
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
29
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/Backend.h>
5
+ #include <c10/core/ScalarType.h>
6
+ #include <c10/core/Layout.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/core/Storage.h>
9
+ #include <ATen/core/DeprecatedTypePropertiesRegistry.h>
10
+ #include <ATen/core/Generator.h>
11
+
12
+
13
+ namespace at {
14
+
15
+ class Tensor;
16
+
17
+ // This class specifies a Backend and a ScalarType. Currently, it primarily
18
+ // serves as a replacement return value for Tensor::type(). Previously,
19
+ // Tensor::type() returned Type&, but we are changing Type to not be
20
+ // dtype-specific.
21
+ class TORCH_API DeprecatedTypeProperties {
22
+ public:
23
+ DeprecatedTypeProperties(Backend backend, ScalarType scalar_type)
24
+ : backend_(backend), scalar_type_(scalar_type) {}
25
+
26
+ Backend backend() const {
27
+ return backend_;
28
+ }
29
+
30
+ Layout layout() const {
31
+ return layout_from_backend(backend_);
32
+ }
33
+
34
+ bool is_sparse() const {
35
+ return layout_from_backend(backend()) == kSparse;
36
+ }
37
+
38
+ bool is_sparse_csr() const {
39
+ return layout_from_backend(backend()) == kSparseCsr;
40
+ }
41
+
42
+ c10::DeviceType device_type() const {
43
+ return backendToDeviceType(backend_);
44
+ }
45
+
46
+ bool is_cuda() const {
47
+ return backendToDeviceType(backend_) == kCUDA;
48
+ }
49
+
50
+ ScalarType scalarType() const {
51
+ return scalar_type_;
52
+ }
53
+
54
+ caffe2::TypeMeta typeMeta() const {
55
+ return scalarTypeToTypeMeta(scalar_type_);
56
+ }
57
+
58
+ bool operator==(const DeprecatedTypeProperties& other) const {
59
+ return backend_ == other.backend() && scalar_type_ == other.scalarType();
60
+ }
61
+
62
+ bool operator!=(const DeprecatedTypeProperties& other) const {
63
+ return !(*this == other);
64
+ }
65
+
66
+ std::string toString() const {
67
+ std::string base_str;
68
+ if (backend_ == Backend::Undefined || scalar_type_ == ScalarType::Undefined) {
69
+ base_str = "UndefinedType";
70
+ } else {
71
+ base_str = std::string(at::toString(backend_)) + at::toString(scalar_type_) + "Type";
72
+ }
73
+ return base_str;
74
+ }
75
+
76
+ DeprecatedTypeProperties & toBackend(Backend b) const {
77
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
78
+ b, scalar_type_);
79
+ }
80
+
81
+ DeprecatedTypeProperties & toScalarType(ScalarType s) const {
82
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
83
+ backend_, s);
84
+ }
85
+
86
+ DeprecatedTypeProperties & cpu() const {
87
+ return toBackend(Backend::CPU);
88
+ }
89
+
90
+ DeprecatedTypeProperties & cuda() const {
91
+ return toBackend(Backend::CUDA);
92
+ }
93
+
94
+ DeprecatedTypeProperties & hip() const {
95
+ return toBackend(Backend::HIP);
96
+ }
97
+
98
+ DeprecatedTypeProperties & privateUser1() const {
99
+ return toBackend(Backend::PrivateUse1);
100
+ }
101
+
102
+ /// Constructs the `TensorOptions` from a type and a `device_index`.
103
+ TensorOptions options(int16_t device_index = -1) const {
104
+ return TensorOptions().dtype(typeMeta())
105
+ .device(device_type(), static_cast<c10::DeviceIndex>(device_index))
106
+ .layout(layout());
107
+ }
108
+
109
+ /// Constructs the `TensorOptions` from a type and a Device. Asserts that
110
+ /// the device type matches the device type of the type.
111
+ TensorOptions options(std::optional<Device> device_opt) const {
112
+ if (!device_opt.has_value()) {
113
+ return options(-1);
114
+ } else {
115
+ Device device = device_opt.value();
116
+ AT_ASSERT(device.type() == device_type());
117
+ return options(device.index());
118
+ }
119
+ }
120
+
121
+ operator TensorOptions() const {
122
+ return options();
123
+ }
124
+
125
+ int64_t id() const {
126
+ return static_cast<int64_t>(backend()) *
127
+ static_cast<int64_t>(ScalarType::NumOptions) +
128
+ static_cast<int64_t>(scalarType());
129
+ }
130
+
131
+ Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const;
132
+ Storage unsafeStorageFromTH(void * th_pointer, bool retain) const;
133
+ Tensor copy(const Tensor & src, bool non_blocking=false, std::optional<Device> to_device={}) const;
134
+
135
+ private:
136
+ Backend backend_;
137
+ ScalarType scalar_type_;
138
+ };
139
+
140
+ } // namespace at
141
+
142
+ #else
143
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
144
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // In order to preserve bc, we make DeprecatedTypeProperties instances unique
5
+ // just like they are for Type.
6
+
7
+ #include <c10/core/Backend.h>
8
+ #include <c10/core/ScalarType.h>
9
+ #include <memory>
10
+
11
+ namespace at {
12
+
13
+ class DeprecatedTypeProperties;
14
+
15
+ struct TORCH_API DeprecatedTypePropertiesDeleter {
16
+ void operator()(DeprecatedTypeProperties * ptr);
17
+ };
18
+
19
+ class TORCH_API DeprecatedTypePropertiesRegistry {
20
+ public:
21
+ DeprecatedTypePropertiesRegistry();
22
+
23
+ DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) const;
24
+
25
+ private:
26
+ // NOLINTNEXTLINE(*c-array*)
27
+ std::unique_ptr<DeprecatedTypeProperties> registry
28
+ [static_cast<int>(Backend::NumOptions)]
29
+ [static_cast<int>(ScalarType::NumOptions)];
30
+ };
31
+
32
+ TORCH_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry();
33
+
34
+ } // namespace at
35
+
36
+ #else
37
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
38
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict.h ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/macros/Macros.h>
5
+ #include <c10/macros/Export.h>
6
+ #include <c10/util/TypeTraits.h>
7
+ #include <c10/util/TypeList.h>
8
+ #include <c10/util/intrusive_ptr.h>
9
+ #include <c10/util/order_preserving_flat_hash_map.h>
10
+ #include <optional>
11
+ #include <ATen/core/TensorBody.h>
12
+ #include <ATen/core/jit_type_base.h>
13
+
14
+ namespace c10 {
15
+ struct IValue;
16
+ template<class Key, class Value> class Dict;
17
+ struct Type;
18
+
19
+ namespace impl {
20
+
21
+ using valid_dict_key_types = guts::typelist::typelist<
22
+ int64_t,
23
+ std::string,
24
+ double,
25
+ c10::complex<double>,
26
+ bool,
27
+ at::Tensor
28
+ >;
29
+ }
30
+
31
+ namespace detail {
32
+
33
+ struct DictKeyHash {
34
+ size_t operator()(const IValue& ivalue) const;
35
+ };
36
+
37
+ struct DictKeyEqualTo {
38
+ bool operator()(const IValue& lhs, const IValue& rhs) const;
39
+ };
40
+
41
+ struct DictImpl final : public c10::intrusive_ptr_target {
42
+ using dict_map_type = ska_ordered::order_preserving_flat_hash_map<IValue, IValue, DictKeyHash, DictKeyEqualTo>;
43
+ struct DictElementTypes final {
44
+ TypePtr keyType;
45
+ TypePtr valueType;
46
+ };
47
+
48
+ explicit DictImpl(dict_map_type dict_, DictElementTypes elementTypes_)
49
+ : dict(std::move(dict_))
50
+ , elementTypes(std::move(elementTypes_)) {}
51
+ dict_map_type dict;
52
+
53
+ DictElementTypes elementTypes;
54
+
55
+ intrusive_ptr<DictImpl> copy() const;
56
+ friend TORCH_API bool operator==(const DictImpl& lhs, const DictImpl& rhs);
57
+ };
58
+
59
+ }
60
+
61
+ namespace impl {
62
+ template<class Key, class Value, class Iterator> class DictIterator;
63
+
64
+ /**
65
+ * A reference to an entry in the Dict.
66
+ * Use the `key()` and `value()` methods to read the element.
67
+ */
68
+ template<class Key, class Value, class Iterator>
69
+ class DictEntryRef final {
70
+ public:
71
+ explicit DictEntryRef(Iterator iterator)
72
+ : iterator_(std::move(iterator)) {}
73
+
74
+ decltype(auto) key() const {
75
+ return iterator_->first.template to<Key>();
76
+ }
77
+
78
+ decltype(auto) value() const {
79
+ return iterator_->second.template to<Value>();
80
+ }
81
+
82
+ template<class Value_>
83
+ void setValue(Value_&& value) const {
84
+ static_assert(std::is_constructible_v<Value, Value_>, "Wrong type for the value argument of setValue()");
85
+ iterator_->second = Value(std::forward<Value_>(value));
86
+ }
87
+ ~DictEntryRef() = default;
88
+
89
+ private:
90
+ // allow copying and moving, but only our friends (i.e. the Dict class) can do
91
+ // it. Copying/moving this reference wrapper would be too ambiguous to allow it
92
+ // in the public API.
93
+ DictEntryRef(const DictEntryRef&) = default;
94
+ DictEntryRef& operator=(const DictEntryRef&) = default;
95
+ DictEntryRef(DictEntryRef&&) noexcept = default;
96
+ DictEntryRef& operator=(DictEntryRef&& rhs) & noexcept = default;
97
+
98
+ Iterator iterator_;
99
+ friend class DictIterator<Key, Value, Iterator>;
100
+ friend class Dict<Key, Value>;
101
+ };
102
+
103
+ // this wraps map_type::iterator to make sure user code can't rely
104
+ // on it being the type of the underlying map.
105
+ template<class Key, class Value, class Iterator>
106
+ class DictIterator final {
107
+ public:
108
+ // C++17 friendly std::iterator implementation
109
+ using iterator_category = std::forward_iterator_tag;
110
+ using value_type = DictEntryRef<Key, Value, Iterator>;
111
+ using difference_type = std::ptrdiff_t;
112
+ using pointer = value_type*;
113
+ using reference = value_type&;
114
+
115
+ explicit DictIterator() = default;
116
+ ~DictIterator() = default;
117
+
118
+ DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {}
119
+ DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {}
120
+ DictIterator& operator=(const DictIterator& rhs) = default;
121
+ DictIterator& operator=(DictIterator&& rhs) noexcept {
122
+ entryRef_ = std::move(rhs.entryRef_);
123
+ return *this;
124
+ }
125
+
126
+ DictIterator& operator++() {
127
+ ++entryRef_.iterator_;
128
+ return *this;
129
+ }
130
+
131
+ DictIterator operator++(int) {
132
+ DictIterator copy(*this);
133
+ ++*this;
134
+ return copy;
135
+ }
136
+
137
+ const DictEntryRef<Key, Value, Iterator>& operator*() const {
138
+ return entryRef_;
139
+ }
140
+
141
+ const DictEntryRef<Key, Value, Iterator>* operator->() const {
142
+ return &entryRef_;
143
+ }
144
+
145
+ friend difference_type operator-(const DictIterator& lhs, const DictIterator& rhs) {
146
+ return lhs.entryRef_.iterator_ - rhs.entryRef_.iterator_;
147
+ }
148
+
149
+ private:
150
+ explicit DictIterator(Iterator iterator): entryRef_(std::move(iterator)) {}
151
+
152
+ const Iterator& get_iterator_() const {
153
+ return entryRef_.iterator_;
154
+ }
155
+
156
+ friend bool operator==(const DictIterator& lhs, const DictIterator& rhs) {
157
+ return lhs.get_iterator_() == rhs.get_iterator_();
158
+ }
159
+
160
+ friend bool operator!=(const DictIterator& lhs, const DictIterator& rhs) {
161
+ return lhs.get_iterator_() != rhs.get_iterator_();
162
+ }
163
+
164
+ friend bool operator<(const DictIterator& lhs, const DictIterator& rhs) {
165
+ return lhs.get_iterator_() < rhs.get_iterator_();
166
+ }
167
+
168
+ friend bool operator<=(const DictIterator& lhs, const DictIterator& rhs) {
169
+ return lhs.get_iterator_() <= rhs.get_iterator_();
170
+ }
171
+
172
+ friend bool operator>(const DictIterator& lhs, const DictIterator& rhs) {
173
+ return lhs.get_iterator_() > rhs.get_iterator_();
174
+ }
175
+
176
+ friend bool operator>=(const DictIterator& lhs, const DictIterator& rhs) {
177
+ return lhs.get_iterator_() >= rhs.get_iterator_();
178
+ }
179
+
180
+ DictEntryRef<Key, Value, Iterator> entryRef_;
181
+
182
+ friend class DictIterator<Key, Value, typename c10::detail::DictImpl::dict_map_type::iterator>;
183
+ friend class Dict<Key, Value>;
184
+ };
185
+
186
+ template<class Key, class Value> Dict<Key, Value> toTypedDict(Dict<IValue, IValue> dict);
187
+ template<class Key, class Value> Dict<IValue, IValue> toGenericDict(Dict<Key, Value> dict);
188
+ }
189
+
190
+ /**
191
+ * An object of this class stores a map from Key to Value.
192
+ *
193
+ * This is a pointer type. After a copy, both Dicts
194
+ * will share the same storage:
195
+ *
196
+ * > Dict<int, string> a;
197
+ * > Dict<int, string> b = a;
198
+ * > b.insert(3, "three");
199
+ * > ASSERT("three" == a.at(3));
200
+ *
201
+ * We use this class in the PyTorch kernel API because that
202
+ * allows us to do optimizations and switch out the underlying
203
+ * map implementation without breaking backwards compatibility
204
+ * for the kernel API.
205
+ */
206
+ template<class Key, class Value>
207
+ // NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
208
+ class Dict final {
209
+ private:
210
+ static_assert((std::is_same_v<IValue, Key> && std::is_same_v<IValue, Value>) || guts::typelist::contains<impl::valid_dict_key_types, Key>::value, "Invalid Key type for Dict. We only support int64_t, double, bool, and string.");
211
+
212
+ // impl_ stores the underlying map as a ska_ordered::order_preserving_flat_hash_map.
213
+ // We intentionally don't offer conversion from/to
214
+ // order_preserving_flat_hash_map, return references to it or something like that,
215
+ // because such operations would get expensive if we switch out
216
+ // the actual map implementation.
217
+ // This is an intrusive_ptr because Dict is a pointer type.
218
+ // Invariant: This will never be a nullptr, there will always be a valid
219
+ // DictImpl.
220
+ c10::intrusive_ptr<detail::DictImpl> impl_;
221
+
222
+ explicit Dict(c10::intrusive_ptr<detail::DictImpl>&& impl);
223
+ friend struct IValue;
224
+ template<class K, class V> friend Dict<K, V> impl::toTypedDict(Dict<IValue, IValue>);
225
+ template<class K, class V> friend Dict<IValue, IValue> impl::toGenericDict(Dict<K, V>);
226
+
227
+ public:
228
+ using key_type = Key;
229
+ using mapped_type = Value;
230
+ using size_type = typename detail::DictImpl::dict_map_type::size_type;
231
+ using iterator = impl::DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::iterator>;
232
+
233
+ /**
234
+ * Creates an empty dict.
235
+ */
236
+ explicit Dict();
237
+
238
+ /**
239
+ * Create a generic dict with runtime type information.
240
+ * This only works for c10::impl::GenericDict and is not part of the public API
241
+ * but only supposed to be used internally by PyTorch.
242
+ */
243
+ explicit Dict(TypePtr keyType, TypePtr valueType);
244
+
245
+ ~Dict() = default;
246
+
247
+ Dict(const Dict&) = default;
248
+ Dict& operator=(const Dict&) = default;
249
+
250
+ /**
251
+ * Create a new Dict pointing to a deep copy of the same data.
252
+ * The Dict returned is a new dict with separate storage.
253
+ * Changes in it are not reflected in the original dict or vice versa.
254
+ */
255
+ Dict copy() const;
256
+
257
+ /**
258
+ * Returns an iterator to the first element of the container.
259
+ * If the container is empty, the returned iterator will be equal to end().
260
+ */
261
+ iterator begin() const;
262
+
263
+ /**
264
+ * Returns an iterator to the element following the last element of the container.
265
+ * This element acts as a placeholder; attempting to access it results in undefined behavior.
266
+ */
267
+ iterator end() const;
268
+
269
+ /**
270
+ * Checks if the container has no elements.
271
+ */
272
+ bool empty() const;
273
+
274
+ /**
275
+ * Returns the number of elements in the container.
276
+ */
277
+ size_type size() const;
278
+
279
+ /**
280
+ * Erases all elements from the container. After this call, size() returns zero.
281
+ * Invalidates any references, pointers, or iterators referring to contained elements. May also invalidate past-the-end iterators.
282
+ */
283
+ void clear() const;
284
+
285
+ /**
286
+ * Inserts element(s) into the container, if the container doesn't already contain an element with an equivalent key.
287
+ * May invalidate any references, pointers, or iterators referring to contained elements.
288
+ *
289
+ * @return A pair consisting of an iterator to the inserted element (or to the element that prevented the insertion) and a bool denoting whether the insertion took place.
290
+ */
291
+ template<class Key_, class Value_>
292
+ std::pair<iterator, bool> insert(Key_&& key, Value_&& value) const;
293
+
294
+ /**
295
+ * If an element with the given key already exists, it is overwritten with the given value.
296
+ * Otherwise, a new element with the given key and value are inserted.
297
+ * May invalidate any references, pointers, or iterators referring to contained elements.
298
+ *
299
+ * @return The bool component is true if the insertion took place and false if the assignment took place. The iterator component is pointing at the element that was inserted or updated.
300
+ */
301
+ template<class Key_, class Value_>
302
+ std::pair<iterator, bool> insert_or_assign(Key_&& key, Value_&& value) const;
303
+
304
+ /**
305
+ * Removes the element pointed to by iter.
306
+ * May invalidate any references, pointers, or iterators referring to contained elements.
307
+ * The iterator iter must be valid and dereferenceable. Thus the end() iterator (which is valid, but is not dereferenceable) cannot be used as a value for iter.
308
+ */
309
+ void erase(iterator iter) const;
310
+
311
+ /**
312
+ * Removes the element with the given key, if it exists.
313
+ * May invalidate any references, pointers, or iterators referring to contained elements.
314
+ *
315
+ * @return The number of elements removed. This is either '1' if an element with the key existed, or '0' if it didn't.
316
+ */
317
+ [[nodiscard]] size_t erase(const Key& key) const;
318
+
319
+ /**
320
+ * Returns the mapped value of the element with key equivalent to key.
321
+ * If no such element exists, an exception of type std::out_of_range is thrown.
322
+ */
323
+ Value at(const Key& key) const;
324
+
325
+ /**
326
+ * Finds an element with key equivalent to key.
327
+ *
328
+ * @return Iterator to an element with key equivalent to key.
329
+ * If no such element is found, past-the-end (see end()) iterator is returned.
330
+ */
331
+ iterator find(const Key& key) const;
332
+
333
+ /**
334
+ * Checks if there is an element with key equivalent to key in the container.
335
+ *
336
+ * @return true if there is such an element, otherwise false.
337
+ */
338
+ bool contains(const Key& key) const;
339
+
340
+ /**
341
+ * Increase the capacity so that at least count elements can be stored without
342
+ * having to reallocate or rehash.
343
+ */
344
+ void reserve(size_type count) const;
345
+
346
+ /**
347
+ * Value equality comparison. This function implements Python-like semantics for
348
+ * equality: two dicts with the same identity (e.g. same pointer) trivially
349
+ * compare equal, otherwise each element is compared for equality.
350
+ */
351
+ template <class Key_, class Value_>
352
+ friend bool operator==(
353
+ const Dict<Key_, Value_>& lhs,
354
+ const Dict<Key_, Value_>& rhs);
355
+ template <class Key_, class Value_>
356
+ friend bool operator!=(
357
+ const Dict<Key_, Value_>& lhs,
358
+ const Dict<Key_, Value_>& rhs);
359
+
360
+ /**
361
+ * Identity comparison. Returns true if and only if `rhs` represents the same
362
+ * Dict object as `this`.
363
+ */
364
+ bool is(const Dict& rhs) const;
365
+
366
+ // private API for now because the return type will change to TypePtr
367
+ // instead of std::optional<TypePtr> once types are mandatory.
368
+ TypePtr keyType() const;
369
+ TypePtr valueType() const;
370
+
371
+ // [unsafe set type]
372
+ // These functions mutate the tagged type of this dictionary in place.
373
+ // There is no checking that the members of the dictionary are instances
374
+ // of the new types, nor is there a check that other IValues which
375
+ // hold references to this dictionary have the right static type.
376
+ // This functionality is used only in the unpickler, where at
377
+ // creation type the real type of the dictionary is unknown, but
378
+ // then later recovered from the static type information of the
379
+ // unpickled object.
380
+ void unsafeSetKeyType(TypePtr t);
381
+ void unsafeSetValueType(TypePtr t);
382
+ };
383
+
384
+ namespace impl {
385
+ // GenericDict is how IValue stores dicts. It is, however, not part of the
386
+ // public API. Kernels should use Dicts with concrete Key, Value types instead
387
+ // (maybe except for some internal prim ops).
388
+ using GenericDict = Dict<IValue, IValue>;
389
+
390
+ }
391
+ }
392
+
393
+ namespace torch {
394
+ template<class Key, class Value> using Dict = c10::Dict<Key, Value>;
395
+ }
396
+
397
+ #include <ATen/core/Dict_inl.h> // IWYU pragma: keep
398
+
399
+ #else
400
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
401
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict_inl.h ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/ivalue.h>
5
+ #include <c10/util/hash.h>
6
+
7
+ namespace c10 {
8
+ namespace detail {
9
+ inline bool DictKeyEqualTo::operator()(const IValue& lhs, const IValue& rhs) const {
10
+ if (lhs.isTensor() && rhs.isTensor()) {
11
+ // for tensors, we compare only by identity (following how it's done in Python).
12
+ return lhs.is(rhs);
13
+ }
14
+ // Otherwise, we first compare by identity for efficiency, then by value (see:
15
+ // [container equality])
16
+ return _fastEqualsForContainer(lhs, rhs);
17
+ }
18
+ }
19
+
20
+ template<class T> decltype(auto) getTypePtr();
21
+ std::string toString(const Type& type);
22
+
23
+ namespace impl {
24
+
25
+ template<class Key, class Value>
26
+ Dict<Key, Value> toTypedDict(GenericDict dict) {
27
+ TORCH_INTERNAL_ASSERT(*getTypePtr<Key>() == *dict.impl_->elementTypes.keyType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()), ">. Key types mismatch.");
28
+ TORCH_INTERNAL_ASSERT(*getTypePtr<Value>() == *dict.impl_->elementTypes.valueType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()), ">. Value types mismatch.");
29
+
30
+ return Dict<Key, Value>(std::move(dict.impl_));
31
+ }
32
+
33
+ template<class Key, class Value>
34
+ GenericDict toGenericDict(Dict<Key, Value> dict) {
35
+ return GenericDict(std::move(dict.impl_));
36
+ }
37
+ }
38
+
39
+ namespace detail {
40
+
41
+ inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
42
+ if (ivalue.isInt()) {
43
+ return std::hash<int64_t>()(ivalue.toInt());
44
+ } else if (ivalue.isString()) {
45
+ return std::hash<std::string_view>()(ivalue.toStringView());
46
+ } else if (ivalue.isDouble()) {
47
+ return std::hash<double>()(ivalue.toDouble());
48
+ } else if (ivalue.isComplexDouble()) {
49
+ return c10::hash<c10::complex<double>>()(ivalue.toComplexDouble());
50
+ } else if (ivalue.isBool()) {
51
+ return std::hash<bool>()(ivalue.toBool());
52
+ } else if (ivalue.isTensor()) {
53
+ return std::hash<TensorImpl*>()(ivalue.toTensor().unsafeGetTensorImpl());
54
+ } else if (ivalue.isDevice()) {
55
+ return std::hash<Device>()(ivalue.toDevice());
56
+ } else {
57
+ TORCH_CHECK(false, "Can't hash IValues with tag '", ivalue.tagKind(), "'");
58
+ }
59
+ }
60
+
61
+ inline intrusive_ptr<DictImpl> DictImpl::copy() const {
62
+ return make_intrusive<DictImpl>(dict, elementTypes);
63
+ }
64
+
65
+ }
66
+
67
+ template<class Key, class Value>
68
+ Dict<Key, Value>::Dict()
69
+ :Dict(make_intrusive<detail::DictImpl>(
70
+ detail::DictImpl::dict_map_type(),
71
+ detail::DictImpl::DictElementTypes{getTypePtr<Key>(), getTypePtr<Value>()})) {
72
+ static_assert(!std::is_same_v<Key, IValue>, "This constructor is not valid for Dict<IValue, _>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
73
+ static_assert(!std::is_same_v<Value, IValue>, "This constructor is not valid for Dict<_, IValue>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
74
+ }
75
+
76
+ template<class Key, class Value>
77
+ Dict<Key, Value>::Dict(TypePtr keyType, TypePtr valueType)
78
+ : Dict(make_intrusive<detail::DictImpl>(
79
+ detail::DictImpl::dict_map_type(),
80
+ detail::DictImpl::DictElementTypes {std::move(keyType), std::move(valueType)})) {
81
+ static_assert(std::is_same_v<Key, IValue>, "This constructor is only valid for c10::impl::GenericDict.");
82
+ static_assert(std::is_same_v<Value, IValue>, "This constructor is only valid for c10::impl::GenericDict.");
83
+ }
84
+
85
+ template<class Key, class Value>
86
+ Dict<Key, Value>::Dict(c10::intrusive_ptr<detail::DictImpl>&& impl): impl_(std::move(impl)) {}
87
+
88
+ template<class Key, class Value>
89
+ Dict<Key, Value> Dict<Key, Value>::copy() const {
90
+ return Dict<Key, Value>(impl_->copy());
91
+ }
92
+
93
+ template<class Key, class Value>
94
+ typename Dict<Key, Value>::iterator Dict<Key, Value>::begin() const {
95
+ return iterator{impl_->dict.begin()};
96
+ }
97
+
98
+ template<class Key, class Value>
99
+ typename Dict<Key, Value>::iterator Dict<Key, Value>::end() const {
100
+ return iterator{impl_->dict.end()};
101
+ }
102
+
103
+ template<class Key, class Value>
104
+ bool Dict<Key, Value>::empty() const {
105
+ return impl_->dict.empty();
106
+ }
107
+
108
+ template<class Key, class Value>
109
+ typename Dict<Key, Value>::size_type Dict<Key, Value>::size() const {
110
+ return impl_->dict.size();
111
+ }
112
+
113
+ template<class Key, class Value>
114
+ void Dict<Key, Value>::clear() const {
115
+ impl_->dict.clear();
116
+ }
117
+
118
+ template<class Key, class Value>
119
+ template<class Key_, class Value_>
120
+ std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert(Key_&& key, Value_&& value) const {
121
+ static_assert(std::is_constructible_v<Key, Key_>, "Wrong type for the key argument of Dict::insert");
122
+ static_assert(std::is_constructible_v<Value, Value_>, "Wrong type for the value argument of Dict::insert");
123
+ auto inserted = impl_->dict.emplace(
124
+ Key(std::forward<Key_>(key)),
125
+ Value(std::forward<Value_>(value)));
126
+ return {iterator{inserted.first}, inserted.second};
127
+ }
128
+
129
+ template<class Key, class Value>
130
+ template<class Key_, class Value_>
131
+ std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert_or_assign(Key_&& key, Value_&& value) const {
132
+ static_assert(std::is_constructible_v<Key, Key_>, "Wrong type for the key argument of Dict::insert_or_assign");
133
+ static_assert(std::is_constructible_v<Value, Value_>, "Wrong type for the value argument of Dict::insert_or_assign");
134
+ auto inserted = impl_->dict.insert_or_assign(
135
+ Key(std::forward<Key_>(key)),
136
+ Value(std::forward<Value_>(value)));
137
+ return {iterator{inserted.first}, inserted.second};
138
+ }
139
+
140
+ template<class Key, class Value>
141
+ void Dict<Key, Value>::erase(iterator iter) const {
142
+ impl_->dict.erase(iter.entryRef_.iterator_);
143
+ }
144
+
145
+ template <class Key, class Value>
146
+ [[nodiscard]] size_t Dict<Key, Value>::erase(const Key& key) const {
147
+ return impl_->dict.erase(key);
148
+ }
149
+
150
+ template<class Key, class Value>
151
+ Value Dict<Key, Value>::at(const Key& key) const {
152
+ return impl_->dict.at(key).template to<Value>();
153
+ }
154
+
155
+ template<class Key, class Value>
156
+ typename Dict<Key, Value>::iterator Dict<Key, Value>::find(const Key& key) const {
157
+ return iterator{impl_->dict.find(key)};
158
+ }
159
+
160
+ template<class Key, class Value>
161
+ bool Dict<Key, Value>::contains(const Key& key) const {
162
+ return end() != find(key);
163
+ }
164
+
165
+ template<class Key, class Value>
166
+ void Dict<Key, Value>::reserve(size_type count) const {
167
+ impl_->dict.reserve(count);
168
+ }
169
+
170
+ template<class Key, class Value>
171
+ TypePtr Dict<Key, Value>::keyType() const {
172
+ return impl_->elementTypes.keyType;
173
+ }
174
+
175
+ template<class Key, class Value>
176
+ TypePtr Dict<Key, Value>::valueType() const {
177
+ return impl_->elementTypes.valueType;
178
+ }
179
+ template <class Key, class Value>
180
+ void Dict<Key, Value>::unsafeSetKeyType(TypePtr t) {
181
+ impl_->elementTypes.keyType = std::move(t);
182
+ }
183
+
184
+ template <class Key, class Value>
185
+ void Dict<Key, Value>::unsafeSetValueType(TypePtr t) {
186
+ impl_->elementTypes.valueType = std::move(t);
187
+ }
188
+
189
+ template <class Key_, class Value_>
190
+ bool operator==(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
191
+ // Dicts with the same identity trivially compare equal.
192
+ if (lhs.impl_ == rhs.impl_) {
193
+ return true;
194
+ }
195
+
196
+ // Otherwise compare the values
197
+ return *lhs.impl_ == *rhs.impl_;
198
+ }
199
+
200
+ template <class Key_, class Value_>
201
+ bool operator!=(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
202
+ return !(lhs == rhs);
203
+ }
204
+
205
+ template <class Key, class Value>
206
+ bool Dict<Key, Value>::is(const Dict& rhs) const {
207
+ return this->impl_ == rhs.impl_;
208
+ }
209
+ }
210
+
211
+ #else
212
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
213
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DimVector.h ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <c10/util/DimVector.h>
4
+
5
+ namespace at {
6
+
7
+ // Redeclaring 'DimVector' type and size inside 'at' namespace.
8
+ // This is done to avoid modifying every use into their 'c10'
9
+ // equivalent.
10
+
11
+ using c10::kDimVectorStaticSize;
12
+ using c10::DimVector;
13
+
14
+ } // namespace at
15
+
16
+ #else
17
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
18
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dimname.h ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/symbol.h>
5
+ #include <c10/util/ArrayRef.h>
6
+ #include <optional>
7
+ #include <ostream>
8
+
9
+ namespace at {
10
+
11
+ enum class NameType: uint8_t { BASIC, WILDCARD };
12
+
13
+ struct TORCH_API Dimname {
14
+ static Dimname fromSymbol(Symbol name);
15
+ static Dimname wildcard();
16
+ static bool isValidName(const std::string& name);
17
+
18
+ NameType type() const { return type_; }
19
+ Symbol symbol() const { return name_; }
20
+
21
+ bool isBasic() const { return type_ == NameType::BASIC; }
22
+ bool isWildcard() const { return type_ == NameType::WILDCARD; }
23
+
24
+ bool matches(Dimname other) const;
25
+ std::optional<Dimname> unify(Dimname other) const;
26
+
27
+ private:
28
+ Dimname(Symbol name)
29
+ : name_(name), type_(NameType::BASIC) {}
30
+ Dimname(Symbol name, NameType type)
31
+ : name_(name), type_(type) {}
32
+
33
+ Symbol name_;
34
+ NameType type_;
35
+ };
36
+
37
+ using DimnameList = c10::ArrayRef<Dimname>;
38
+
39
+ TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname);
40
+
41
+ inline bool operator==(const Dimname& lhs, const Dimname& rhs) {
42
+ return lhs.symbol() == rhs.symbol();
43
+ }
44
+
45
+ inline bool operator!=(const Dimname& lhs, const Dimname& rhs) {
46
+ return !(lhs == rhs);
47
+ }
48
+
49
+ } // namespace at
50
+
51
+ #else
52
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
53
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DistributionsHelper.h ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/TransformationHelper.h>
5
+ #include <c10/util/Half.h>
6
+ #include <c10/util/BFloat16.h>
7
+ #include <c10/util/MathConstants.h>
8
+ #include <c10/macros/Macros.h>
9
+
10
+ #include <cmath>
11
+ #include <limits>
12
+ #include <optional>
13
+ #include <type_traits>
14
+
15
+ /**
16
+ * Distributions kernel adapted from THRandom.cpp
17
+ * The kernels try to follow std::random distributions signature
18
+ * For instance: in ATen
19
+ * auto gen = at::detail::createCPUGenerator();
20
+ * at::uniform_real_distribution<double> uniform(0, 1);
21
+ * auto sample = uniform(gen.get());
22
+ *
23
+ * vs std::random
24
+ *
25
+ * std::mt19937 gen;
26
+ * std::uniform_real_distribution uniform(0, 1);
27
+ * auto sample = uniform(gen);
28
+ */
29
+
30
+
31
+ namespace at {
32
+ namespace {
33
+
34
+ /**
35
+ * Samples a discrete uniform distribution in the range [base, base+range) of type T
36
+ */
37
+ template <typename T>
38
+ struct uniform_int_from_to_distribution {
39
+
40
+ C10_HOST_DEVICE inline uniform_int_from_to_distribution(uint64_t range, int64_t base) : range_(range), base_(base) {}
41
+
42
+ template <typename RNG>
43
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
44
+ #ifdef FBCODE_CAFFE2
45
+ if ((
46
+ std::is_same_v<T, int64_t> ||
47
+ std::is_same_v<T, double> ||
48
+ std::is_same_v<T, float> ||
49
+ std::is_same_v<T, at::BFloat16>) && range_ >= 1ULL << 32)
50
+ #else
51
+ if (range_ >= 1ULL << 28) // allow approx 5% skew in uniform int generation using %
52
+ #endif
53
+ {
54
+ return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
55
+ } else {
56
+ return transformation::uniform_int_from_to<T>(generator->random(), range_, base_);
57
+ }
58
+ }
59
+
60
+ private:
61
+ uint64_t range_;
62
+ int64_t base_;
63
+ };
64
+
65
+ /**
66
+ * Samples a discrete uniform distribution in the range [min_value(int64_t), max_value(int64_t)]
67
+ */
68
+ template <typename T>
69
+ struct uniform_int_full_range_distribution {
70
+
71
+ template <typename RNG>
72
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
73
+ return transformation::uniform_int_full_range<T>(generator->random64());
74
+ }
75
+
76
+ };
77
+
78
+ /**
79
+ * Samples a discrete uniform distribution in the range [0, max_value(T)] for integral types
80
+ * and [0, 2^mantissa] for floating-point types.
81
+ */
82
+ template <typename T>
83
+ struct uniform_int_distribution {
84
+
85
+ template <typename RNG>
86
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
87
+ if constexpr (std::is_same_v<T, double> || std::is_same_v<T, int64_t>) {
88
+ return transformation::uniform_int<T>(generator->random64());
89
+ } else {
90
+ return transformation::uniform_int<T>(generator->random());
91
+ }
92
+ }
93
+
94
+ };
95
+
96
+ /**
97
+ * Samples a uniform distribution in the range [from, to) of type T
98
+ */
99
+ template <typename T>
100
+ struct uniform_real_distribution {
101
+
102
+ C10_HOST_DEVICE inline uniform_real_distribution(T from, T to) : from_(from), to_(to) {
103
+ TORCH_CHECK_IF_NOT_ON_CUDA(from <= to);
104
+ TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits<T>::max());
105
+ }
106
+
107
+ template <typename RNG>
108
+ C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
109
+ if constexpr (std::is_same_v<T, double>) {
110
+ return transformation::uniform_real<T>(generator->random64(), from_, to_);
111
+ } else {
112
+ return transformation::uniform_real<T>(generator->random(), from_, to_);
113
+ }
114
+ }
115
+
116
+ private:
117
+ T from_;
118
+ T to_;
119
+ };
120
+
121
+ // The SFINAE checks introduced in #39816 looks overcomplicated and must revisited
122
+ // https://github.com/pytorch/pytorch/issues/40052
123
+ #define DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member) \
124
+ template <typename T> \
125
+ struct has_member_##member \
126
+ { \
127
+ typedef char yes; \
128
+ typedef long no; \
129
+ template <typename U> static yes test(decltype(&U::member)); \
130
+ template <typename U> static no test(...); \
131
+ static constexpr bool value = sizeof(test<T>(0)) == sizeof(yes); \
132
+ }
133
+
134
+ DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_double_normal_sample);
135
+ DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_double_normal_sample);
136
+ DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_float_normal_sample);
137
+ DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_float_normal_sample);
138
+
139
+ #define DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(TYPE) \
140
+ \
141
+ template <typename RNG, typename ret_type, \
142
+ typename std::enable_if_t<( \
143
+ has_member_next_##TYPE##_normal_sample<RNG>::value && \
144
+ has_member_set_next_##TYPE##_normal_sample<RNG>::value \
145
+ ), int> = 0> \
146
+ C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) { \
147
+ if (generator->next_##TYPE##_normal_sample()) { \
148
+ *ret = *(generator->next_##TYPE##_normal_sample()); \
149
+ generator->set_next_##TYPE##_normal_sample(std::optional<TYPE>()); \
150
+ return true; \
151
+ } \
152
+ return false; \
153
+ } \
154
+ \
155
+ template <typename RNG, typename ret_type, \
156
+ typename std::enable_if_t<( \
157
+ !has_member_next_##TYPE##_normal_sample<RNG>::value || \
158
+ !has_member_set_next_##TYPE##_normal_sample<RNG>::value \
159
+ ), int> = 0> \
160
+ C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type* /*ret*/) { \
161
+ return false; \
162
+ } \
163
+ \
164
+ template <typename RNG, typename ret_type, \
165
+ typename std::enable_if_t<( \
166
+ has_member_set_next_##TYPE##_normal_sample<RNG>::value \
167
+ ), int> = 0> \
168
+ C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \
169
+ generator->set_next_##TYPE##_normal_sample(cache); \
170
+ } \
171
+ \
172
+ template <typename RNG, typename ret_type, \
173
+ typename std::enable_if_t<( \
174
+ !has_member_set_next_##TYPE##_normal_sample<RNG>::value \
175
+ ), int> = 0> \
176
+ C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type /*cache*/) { \
177
+ }
178
+
179
+ DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(double)
180
+ DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float)
181
+
182
+ /**
183
+ * Samples a normal distribution using the Box-Muller method
184
+ * Takes mean and standard deviation as inputs
185
+ * Note that Box-muller method returns two samples at a time.
186
+ * Hence, we cache the "next" sample in the CPUGeneratorImpl class.
187
+ */
188
+ template <typename T>
189
+ struct normal_distribution {
190
+
191
+ C10_HOST_DEVICE inline normal_distribution(T mean_in, T stdv_in) : mean(mean_in), stdv(stdv_in) {
192
+ TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in >= 0, "stdv_in must be positive: ", stdv_in);
193
+ }
194
+
195
+ template <typename RNG>
196
+ C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
197
+ dist_acctype<T> ret;
198
+ // return cached values if available
199
+ if constexpr (std::is_same_v<T, double>) {
200
+ if (maybe_get_next_double_normal_sample(generator, &ret)) {
201
+ return transformation::normal(ret, mean, stdv);
202
+ }
203
+ } else {
204
+ if (maybe_get_next_float_normal_sample(generator, &ret)) {
205
+ return transformation::normal(ret, mean, stdv);
206
+ }
207
+ }
208
+ // otherwise generate new normal values
209
+ uniform_real_distribution<T> uniform(0.0, 1.0);
210
+ const dist_acctype<T> u1 = uniform(generator);
211
+ const dist_acctype<T> u2 = uniform(generator);
212
+ const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log1p(-u2));
213
+ const dist_acctype<T> theta = static_cast<T>(2.0) * c10::pi<T> * u1;
214
+ if constexpr (std::is_same_v<T, double>) {
215
+ maybe_set_next_double_normal_sample(generator, r * ::sin(theta));
216
+ } else {
217
+ maybe_set_next_float_normal_sample(generator, r * ::sin(theta));
218
+ }
219
+ ret = r * ::cos(theta);
220
+ return transformation::normal(ret, mean, stdv);
221
+ }
222
+
223
+ private:
224
+ T mean;
225
+ T stdv;
226
+ };
227
+
228
+ template <typename T>
229
+ struct DiscreteDistributionType { using type = float; };
230
+
231
+ template <> struct DiscreteDistributionType<double> { using type = double; };
232
+
233
+ /**
234
+ * Samples a bernoulli distribution given a probability input
235
+ */
236
+ template <typename T>
237
+ struct bernoulli_distribution {
238
+
239
+ C10_HOST_DEVICE inline bernoulli_distribution(T p_in) : p(p_in) {
240
+ TORCH_CHECK_IF_NOT_ON_CUDA(p_in >= 0 && p_in <= 1);
241
+ }
242
+
243
+ template <typename RNG>
244
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
245
+ uniform_real_distribution<T> uniform(0.0, 1.0);
246
+ return transformation::bernoulli<T>(uniform(generator), p);
247
+ }
248
+
249
+ private:
250
+ T p;
251
+ };
252
+
253
+ /**
254
+ * Samples a geometric distribution given a probability input
255
+ */
256
+ template <typename T>
257
+ struct geometric_distribution {
258
+
259
+ C10_HOST_DEVICE inline geometric_distribution(T p_in) : p(p_in) {
260
+ TORCH_CHECK_IF_NOT_ON_CUDA(p_in > 0 && p_in < 1);
261
+ }
262
+
263
+ template <typename RNG>
264
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
265
+ uniform_real_distribution<T> uniform(0.0, 1.0);
266
+ return transformation::geometric<T>(uniform(generator), p);
267
+ }
268
+
269
+ private:
270
+ T p;
271
+ };
272
+
273
+ /**
274
+ * Samples an exponential distribution given a lambda input
275
+ */
276
+ template <typename T>
277
+ struct exponential_distribution {
278
+
279
+ C10_HOST_DEVICE inline exponential_distribution(T lambda_in) : lambda(lambda_in) {}
280
+
281
+ template <typename RNG>
282
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
283
+ uniform_real_distribution<T> uniform(0.0, 1.0);
284
+ return transformation::exponential<T>(uniform(generator), lambda);
285
+ }
286
+
287
+ private:
288
+ T lambda;
289
+ };
290
+
291
+ /**
292
+ * Samples a cauchy distribution given median and sigma as inputs
293
+ */
294
+ template <typename T>
295
+ struct cauchy_distribution {
296
+
297
+ C10_HOST_DEVICE inline cauchy_distribution(T median_in, T sigma_in) : median(median_in), sigma(sigma_in) {}
298
+
299
+ template <typename RNG>
300
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
301
+ uniform_real_distribution<T> uniform(0.0, 1.0);
302
+ return transformation::cauchy<T>(uniform(generator), median, sigma);
303
+ }
304
+
305
+ private:
306
+ T median;
307
+ T sigma;
308
+ };
309
+
310
+ /**
311
+ * Samples a lognormal distribution
312
+ * Takes mean and standard deviation as inputs
313
+ * Outputs two samples at a time
314
+ */
315
+ template <typename T>
316
+ struct lognormal_distribution {
317
+
318
+ C10_HOST_DEVICE inline lognormal_distribution(T mean_in, T stdv_in) : mean(mean_in), stdv(stdv_in) {
319
+ TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0);
320
+ }
321
+
322
+ template<typename RNG>
323
+ C10_HOST_DEVICE inline T operator()(RNG generator){
324
+ normal_distribution<T> normal(mean, stdv);
325
+ return transformation::log_normal<T>(normal(generator));
326
+ }
327
+
328
+ private:
329
+ T mean;
330
+ T stdv;
331
+ };
332
+ }
333
+ } // namespace at
334
+
335
+ #else
336
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
337
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Formatting.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ostream>
5
+ #include <string>
6
+
7
+ #include <c10/core/Scalar.h>
8
+ #include <ATen/core/Tensor.h>
9
+
10
+ namespace c10 {
11
+ TORCH_API std::ostream& operator<<(std::ostream& out, Backend b);
12
+ TORCH_API std::ostream& operator<<(std::ostream & out, const Scalar& s);
13
+ TORCH_API std::string toString(const Scalar& s);
14
+ }
15
+ namespace at {
16
+
17
+ TORCH_API std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t);
18
+ TORCH_API std::ostream& print(
19
+ std::ostream& stream,
20
+ const Tensor& tensor,
21
+ int64_t linesize);
22
+ inline std::ostream& operator<<(std::ostream & out, const Tensor & t) {
23
+ return print(out,t,80);
24
+ }
25
+ TORCH_API void print(const Tensor & t, int64_t linesize=80);
26
+ }
27
+
28
+ #else
29
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
30
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Generator.h ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <cstdint>
5
+ #include <deque>
6
+ #include <mutex>
7
+ #include <utility>
8
+
9
+ #include <c10/util/Exception.h>
10
+ #include <c10/util/intrusive_ptr.h>
11
+ #include <c10/core/Device.h>
12
+ #include <c10/core/DispatchKeySet.h>
13
+
14
+ // For the record I don't think this is a correct pimpl idiom.
15
+ // Including Impl header in interface header defeats the purpose
16
+ // because you can't change Impl private members without forcing
17
+ // everything that included the interface to rebuild.
18
+ // Impl should be forward-declared in the interface header instead.
19
+ #include <c10/core/GeneratorImpl.h>
20
+
21
+ /**
22
+ * Note [Generator]
23
+ * ~~~~~~~~~~~~~~~~
24
+ * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm to
25
+ * generate a seemingly random sequence of numbers, that may be later be used in creating
26
+ * a random distribution. Such an engine almost always maintains a state and requires a
27
+ * seed to start off the creation of random numbers. Often times, users have
28
+ * found it beneficial to be able to explicitly create, retain, and destroy
29
+ * PRNG states and also be able to have control over the seed value.
30
+ *
31
+ * A Generator in ATen gives users the ability to read, write and modify a PRNG engine.
32
+ * For instance, it does so by letting users seed a PRNG engine, fork the state of the
33
+ * engine, etc.
34
+ *
35
+ * By default, there is one generator per device, and a device's generator is
36
+ * lazily created. A user can use the torch.Generator() api to create their own generator.
37
+ */
38
+
39
+ /**
40
+ * Note [Acquire lock when using random generators]
41
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
42
+ * Generator and its derived classes are NOT thread-safe. Please note that most of the
43
+ * places where we have inserted locking for generators are historically based, and we
44
+ * haven't actually checked that everything is truly thread safe (and it probably isn't).
45
+ * Please use the public mutex_ when using any methods from these classes, except for the
46
+ * read-only methods. You can learn about the usage by looking into the unittests
47
+ * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard.
48
+ *
49
+ * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making
50
+ * them non-thread safe and instead making the generator state splittable, to accommodate
51
+ * forks into other threads).
52
+ */
53
+
54
+ namespace at {
55
+
56
+ class Tensor;
57
+
58
+ struct TORCH_API Generator {
59
+ Generator() = default;
60
+
61
+ explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
62
+ : impl_(std::move(gen_impl)) {
63
+ TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
64
+ }
65
+
66
+ bool operator==(const Generator& rhs) const {
67
+ return this->impl_ == rhs.impl_;
68
+ }
69
+
70
+ bool operator!=(const Generator& rhs) const {
71
+ return !((*this) == rhs);
72
+ }
73
+
74
+ bool defined() const {
75
+ return static_cast<bool>(impl_);
76
+ }
77
+
78
+ c10::GeneratorImpl* unsafeGetGeneratorImpl() const {
79
+ return impl_.get();
80
+ }
81
+
82
+ c10::GeneratorImpl* unsafeReleaseGeneratorImpl() {
83
+ return impl_.release();
84
+ }
85
+
86
+ const c10::intrusive_ptr<c10::GeneratorImpl>& getIntrusivePtr() const {
87
+ return impl_;
88
+ }
89
+
90
+ void set_current_seed(uint64_t seed) { impl_->set_current_seed(seed); }
91
+ // Sets the offset of Generator state to the desired offset. This is currently
92
+ // supported for only Philox based Generators, i.e., CUDA and MPS.
93
+ void set_offset(uint64_t offset) { impl_->set_offset(offset); }
94
+
95
+ // Returns the offset of Generator state. This is currently supported for only
96
+ // Philox based Generators, i.e., CUDA and MPS.
97
+ uint64_t get_offset() const { return impl_->get_offset(); }
98
+
99
+ uint64_t current_seed() const { return impl_->current_seed(); }
100
+
101
+ uint64_t seed() { return impl_->seed(); }
102
+
103
+ // Implementation not inlined to prevent cycle reference between
104
+ // `ATen/core/Generator.h` and `ATen/core/Tensor.h`
105
+ void set_state(const at::Tensor& new_state);
106
+
107
+ at::Tensor get_state() const;
108
+
109
+ void graphsafe_set_state(const Generator& new_state);
110
+
111
+ Generator graphsafe_get_state() const;
112
+
113
+ std::mutex& mutex() {
114
+ return impl_->mutex_;
115
+ }
116
+
117
+ DispatchKeySet key_set() const {
118
+ return impl_->key_set();
119
+ }
120
+
121
+ Device device() const { return impl_->device(); }
122
+
123
+ inline void set_pyobj(PyObject* pyobj) const noexcept {
124
+ impl_->set_pyobj(pyobj);
125
+ }
126
+
127
+ inline PyObject* pyobj() const noexcept {
128
+ return impl_->pyobj();
129
+ }
130
+
131
+ template<typename T>
132
+ T* get() const { return static_cast<T*>(impl_.get()); }
133
+
134
+ Generator clone() const {
135
+ return Generator(impl_->clone());
136
+ }
137
+
138
+ private:
139
+ c10::intrusive_ptr<c10::GeneratorImpl> impl_;
140
+ };
141
+
142
+ template<class Impl, class... Args>
143
+ Generator make_generator(Args&&... args) {
144
+ return Generator(c10::make_intrusive<Impl>(std::forward<Args>(args)...));
145
+ }
146
+
147
+ /**
148
+ * Utility function to static cast input Generator* to
149
+ * the backend generator type (CPU/CUDAGeneratorImpl etc.)
150
+ */
151
+ template <typename T>
152
+ inline T * check_generator(std::optional<Generator> gen) {
153
+ TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");
154
+ TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");
155
+ TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");
156
+ return gen->get<T>();
157
+ }
158
+
159
+ /**
160
+ * Utility function used in tensor implementations, which
161
+ * supplies the default generator to tensors, if an input generator
162
+ * is not supplied. The input Generator* is also static casted to
163
+ * the backend generator type (CPU/CUDAGeneratorImpl etc.)
164
+ */
165
+ template <typename T>
166
+ inline T* get_generator_or_default(const std::optional<Generator>& gen, const Generator& default_gen) {
167
+ return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
168
+ }
169
+
170
+ namespace detail {
171
+
172
+ /**
173
+ * Helper function for checking the validity of new random generator
174
+ * state. Right now following conditions are checked:
175
+ *
176
+ * - The new state tensor must be a torch.ByteTensor
177
+ * - Data of the new state tensor must be contiguous
178
+ */
179
+ inline void check_rng_state(const c10::TensorImpl& new_state) {
180
+ TORCH_CHECK_TYPE(
181
+ new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte,
182
+ "RNG state must be a torch.ByteTensor"
183
+ );
184
+
185
+ TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous");
186
+ }
187
+
188
+ } // namespace detail
189
+
190
+ } // namespace at
191
+
192
+ #else
193
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
194
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/Generator.h>
5
+ #include <c10/util/intrusive_ptr.h>
6
+
7
+ namespace at {
8
+
9
+ using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>;
10
+
11
+ TORCH_API std::optional<GeneratorFuncType>& GetGeneratorPrivate();
12
+
13
+ class TORCH_API _GeneratorRegister {
14
+ public:
15
+ explicit _GeneratorRegister(const GeneratorFuncType& func);
16
+ };
17
+
18
+ TORCH_API at::Generator GetGeneratorForPrivateuse1(
19
+ c10::DeviceIndex device_index);
20
+
21
+ /**
22
+ * This is used to register Generator to PyTorch for `privateuse1` key.
23
+ *
24
+ * Usage: REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1)
25
+ *
26
+ * class CustomGeneratorImpl : public c10::GeneratorImpl {
27
+ * CustomGeneratorImpl(DeviceIndex device_index = -1);
28
+ * explicit ~CustomGeneratorImpl() override = default;
29
+ * ...
30
+ * };
31
+ *
32
+ * at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) {
33
+ * return at::make_generator<CustomGeneratorImpl>(id);
34
+ * }
35
+ */
36
+
37
+ #define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \
38
+ static auto temp##GeneratorPrivate = at::_GeneratorRegister(GeneratorPrivate);
39
+
40
+ } // namespace at
41
+
42
+ #else
43
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
44
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef.h ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/ivalue_to.h>
5
+ #include <c10/util/ArrayRef.h>
6
+ #include <c10/util/Exception.h>
7
+
8
+ #include <functional>
9
+ #include <initializer_list>
10
+ #include <iterator>
11
+ #include <type_traits>
12
+
13
+ /*
14
+ * [Note: IListRef]
15
+ * Wrapper around different API containers (e.g. boxed and unboxed).
16
+ *
17
+ * What is it?
18
+ * ===========
19
+ * It is a tagged union of both boxed and unboxed API containers.
20
+ * Working implementations:
21
+ *
22
+ * - `IListRef<at::Tensor>`
23
+ * - `IListRef<at::OptionalTensorRef>`
24
+ *
25
+ * Note that `IListRef` is a view type. Meaning that it won't own the
26
+ * tensors it holds. It's intended to be used only as argument parameters.
27
+ * Specifically, where these 2 worlds overlap.
28
+ *
29
+ * What is this for?
30
+ * =================
31
+ * Historically, PyTorch has maintained 2 different APIs: the unboxed
32
+ * (called from C++ API and Python eager mode) and boxed APIs (called
33
+ * from the TorchScript JIT, mobile interpreter, and boxed fallbacks).
34
+ *
35
+ * Calling unboxed kernels from the boxed "world" and vice-versa may
36
+ * result in non-negligible overhead. Lists are one of those types:
37
+ *
38
+ * - Boxed world: `c10::List`
39
+ * - Unboxed world: `c10::ArrayRef`
40
+ *
41
+ * In this context, `c10::IListRef` solves this problem by wrapping those
42
+ * 2 container types, so that we don't need to convert from one to
43
+ * the other.
44
+ *
45
+ * (see https://github.com/pytorch/pytorch/issues/66328)
46
+ *
47
+ * What does it do?
48
+ * ================
49
+ * This container wraps around the different tagged containers
50
+ * (currently, only boxed and unboxed), without incurring in extra
51
+ * overhead for converting from one to another. It does so while
52
+ * exposing usual container methods, which dispatch to corresponding
53
+ * implementations.
54
+ *
55
+ * While it works with different container types, it introduces
56
+ * overhead for repeatedly calling member functions (since those will
57
+ * get dispatched, again). Therefore, you should only use it to iterate
58
+ * through the list up to one time. If you need to do more complex things,
59
+ * call `materialize()` first.
60
+ *
61
+ * Adding support for a new Tag
62
+ * ============================
63
+ * Suppose we want to add a new tag: `Chest`. Here are the steps
64
+ * we would have to go through:
65
+ *
66
+ * 1. Add a line for it in the macro `TORCH_ILISTREF_FORALL_TAGS`.
67
+ *
68
+ * #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \
69
+ * ...
70
+ * _(Chest, ##__VA_ARGS__)
71
+ *
72
+ * 2. Add type aliases, union members, and constructors.
73
+ *
74
+ * template <typename T>
75
+ * class IListRef {
76
+ * ...
77
+ * using chest_type =
78
+ * typename detail::IListRefTagImpl<T, IListRefTag::Chest>::list_type;
79
+ * ...
80
+ * IListRef(...) : tag_(IListRefTag::Chest) {
81
+ * ...
82
+ * }
83
+ * ...
84
+ * union Payload {
85
+ * ...
86
+ * chest_type chest;
87
+ * ...
88
+ * };
89
+ * ...
90
+ * };
91
+ *
92
+ * 3. Add a default implementation for it (in 'IListRef_inl.h'). It's
93
+ * preferable to make the default implementation work for `T = Tensor`
94
+ * (both `Unboxed` and `Boxed` do it).
95
+ *
96
+ * template <typename T, typename ListElemT>
97
+ * class IListRefTagImplBase<IListRefTag::Chest, T, ListElemT> {
98
+ * public:
99
+ * using elem_type = ListElemT;
100
+ * using list_type = ChestContainer<elem_type>;
101
+ *
102
+ * static const list_type& unwrap(const IListRef<T>& ilist) { ... }
103
+ *
104
+ * static typename list_type::const_iterator& unwrap(
105
+ * IListRefIterator<T>& it) { ... }
106
+ *
107
+ * static const typename list_type::const_iterator& unwrap(
108
+ * const IListRefIterator<T>& it) { ... }
109
+ *
110
+ * static IListRefConstRef<T> iterator_get(
111
+ * const typename list_type::const_iterator& it) { ... }
112
+ * }
113
+ *
114
+ * 4. Add an specialization for each of the already supported types.
115
+ * Finally, for consistency, add them to the tracking list.
116
+ * (see [Note: IListRefTagImpl Specializations])
117
+ *
118
+ * template <>
119
+ * class IListRefTagImpl<IListRefTag::Chest, at::Tensor>
120
+ * : public IListRefTagImplBase<IListRefTag::Chest, at::Tensor> {};
121
+ *
122
+ * Adding support for a new Type
123
+ * =============================
124
+ * Suppose we want to add support for a new type: `Matrix`.
125
+ * Here are the steps we would have to go through:
126
+ *
127
+ * 1. Add an specialization for each of the existing tags.
128
+ * For consistency, add them to the tracking list.
129
+ * (see [Note: IListRefTagImpl Specializations])
130
+ *
131
+ * template <>
132
+ * class IListRefTagImpl<IListRefTag::Unboxed, Matrix>
133
+ * : public IListRefTagImplBase<IListRefTag::Unboxed, Matrix> {};
134
+ *
135
+ * template <>
136
+ * class IListRefTagImpl<Matrix, IListRefTag::Boxed>
137
+ * : public IListRefTagImplBase<IListRefTag::Boxed, Matrix> {};
138
+ *
139
+ * Common Problems
140
+ * ===============
141
+ * 1. One of `IListRef(Iterator)` methods are failing to compile.
142
+ *
143
+ * That may be happening because the container type you added
144
+ * is not compatible with the code written for that method. If
145
+ * that's true, then you might have to transform that code into
146
+ * a static method call (see `List::operator[]` method).
147
+ *
148
+ * 2. Can't make `IListRefIterator<T>::operator*` return a const-reference.
149
+ *
150
+ * First, keep in mind that we assume that boxed containers will
151
+ * have to deal with `IValue` (e.g. `c10::List`). In this context,
152
+ * what may be happening is that `IValue` doesn't store internally
153
+ * your type `T`. Instead, it constructs a type new `T` every time
154
+ * you try to get `T` for it (see `IListRef<at::OptinalTensorRef>`).
155
+ */
156
+
157
+ namespace c10 {
158
+ template <typename T>
159
+ class IListRef;
160
+
161
+ /*
162
+ * Applies arbitrary macros to each `IListRefTag`.
163
+ */
164
+ #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \
165
+ _(Unboxed, ##__VA_ARGS__) \
166
+ _(Boxed, ##__VA_ARGS__) \
167
+ _(Materialized, ##__VA_ARGS__)
168
+
169
+ /*
170
+ * Defines a "switch-case" for `TAG`. Inside, it executes `BODY`,
171
+ * while bringing to scope:
172
+ *
173
+ * - `ImplT`: the implementation class for `TAG`
174
+ * - `this_`: the result of unwrapping `this`
175
+ */
176
+ #define TORCH_ILISTREF_UNWRAP_CASE(TAG, BODY) \
177
+ case c10::IListRefTag::TAG: { \
178
+ using ImplT = c10::detail::IListRefTagImpl<IListRefTag::TAG, T>; \
179
+ auto& this_ = ImplT::unwrap(*this); \
180
+ BODY \
181
+ } break;
182
+
183
+ /*
184
+ * Dispatches the unwrap call, depending on `TAG`, followed by
185
+ * the execution of `BODY`. It aborts if `TAG` is not a `IListRefTag`.
186
+ *
187
+ * This macro is useful because it allows us to handle different
188
+ * types (that correspond to different tags) to be implemented
189
+ * only once. We can do it even when the implementation of the
190
+ * different tags aren't syntactically the same, by dispatching
191
+ * it to a function (e.g. `ImplT::<dispatch-function>(this_)`).
192
+ */
193
+ #define TORCH_ILISTREF_UNWRAP(TAG, BODY) \
194
+ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
195
+ switch (TAG) { \
196
+ TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \
197
+ break; \
198
+ default: \
199
+ TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); \
200
+ } \
201
+ C10_DIAGNOSTIC_POP()
202
+
203
+ enum class IListRefTag {
204
+ #define DEFINE_TAG(tag, ...) tag,
205
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_TAG)
206
+ #undef DEFINE_TAG
207
+ None
208
+ };
209
+
210
+ namespace detail {
211
+ /*
212
+ * Type alias that specifies whether we return a reference or a copy of `T`.
213
+ *
214
+ * What is this for?
215
+ * =================
216
+ * Since values in the boxed world are represented by an `IValue`, we also
217
+ * depend on whether it can be converted to a const-reference (`Tensor`) or
218
+ * has to create a new copy of `T` (`OptionalTensorRef`).
219
+ */
220
+ template <typename T>
221
+ using IListRefConstRef = typename ivalue_to_const_ref_overload_return<T>::type;
222
+
223
+ /*
224
+ * Interface that implements key functions for each `IListRefTag` type.
225
+ *
226
+ * What is this for?
227
+ * =================
228
+ * Given an `IListRef(Iterator)<T>`, some methods have to be implemented
229
+ * differently for each `TAG`. Therefore, the methods inside this class
230
+ * are used as dispatch targets for the different `IListRefTag` values.
231
+ *
232
+ * You should create an specialization of this class for each possible
233
+ * combination of `IListRefTag` type (except `None`) and element types
234
+ * (e.g. `Tensor`).
235
+ *
236
+ * What does it do?
237
+ * ================
238
+ * 1. defines static methods to be used as dispatch targets by both
239
+ * `IListRef<T>` and `IListRefIterator<T>` (see the implementation of
240
+ * `IListRefTagImplBase`).
241
+ *
242
+ * 2. defines the `elem_type` and `list_type` aliases that will be
243
+ * used in the definition of `IListRef<T>`. In general, we should do
244
+ * so by inheriting from `IListRefTagImplBase<TAG, T, ListElemT>`.
245
+ *
246
+ * [Note: IListRefTagImpl Specialization]
247
+ * ======================================
248
+ * For `IListRef(Iterator)<at::Tensor>`:
249
+ * - <IListRefTag::Unboxed, at::Tensor>
250
+ * - <IListRefTag::Boxed, at::Tensor>
251
+ * - <IListRefTag::Materialized, at::Tensor>
252
+ *
253
+ * For `IListRef(Iterator)<at::OptionalTensorRef>`:
254
+ * - <IListRefTag::Unboxed, at::OptionalTensorRef>
255
+ * - <IListRefTag::Boxed, at::OptionalTensorRef>
256
+ * - <IListRefTag::Materialized, at::OptionalTensorRef>
257
+ */
258
+ template <IListRefTag TAG, typename T>
259
+ class IListRefTagImpl {};
260
+
261
+ /*
262
+ * Base implementation of `IListRefTagImpl<TAG, T>` methods.
263
+ *
264
+ * What is this for?
265
+ * =================
266
+ * This should make adding specializations for new types easier. For
267
+ * example, one should be able to add a new type just by making its
268
+ * `IListRefTagImpl` specialization inherit from `IListRefTagImplBase`.
269
+ *
270
+ * You should create a partial specialization for this class only if
271
+ * you introduce a new `IListRefTag`. The idea being that there is one
272
+ * default implementation for each possible value of `IListRefTag`.
273
+ *
274
+ * What does it do?
275
+ * ================
276
+ * 1. defines `elem_type` as an alias to `ListElemT`.
277
+ *
278
+ * 1. defines `list_type` as an alias to the default container type
279
+ * that will hold a collection of `elem_type`. The idea being that
280
+ * all types tagged as `TAG` will have `list_type` as its container,
281
+ * with different `elem_type`.
282
+ *
283
+ * 3. defines the default implementation for each of the methods that
284
+ * are supposed to be defined on `IListRefTagImpl` specializations.
285
+ *
286
+ * 4. inheriting from `IListRefTagImplBase<TAG, T, ListElemT>` also means
287
+ * that the payload of the type `IListRef<T>` will be of type `list_type`
288
+ * when it is tagged as `TAG`.
289
+ */
290
+ template <IListRefTag TAG, typename T, typename ListElemT = T>
291
+ class IListRefTagImplBase {};
292
+
293
+ /*
294
+ * Materialized container for `IListRef<T>`.
295
+ *
296
+ * What is this for?
297
+ * =================
298
+ * Container that groups `T` references together. This exchanges the
299
+ * overhead of every method call from `IListRef<T>` for a dynamic allocation.
300
+ *
301
+ * You should use this container instead of `IListRef<T>` if:
302
+ *
303
+ * - You are going to iterate the list more than once
304
+ * - You need to repeatedly access arbitrary elements (using `operator[]`)
305
+ * What does it do?
306
+
307
+ * ================
308
+ * Removes the reference (&) from the type, and wraps it into a
309
+ * `std::reference_wrapper`. If `IListRefConstRef<T>` is not a
310
+ * reference type, then it's left unchanged.
311
+ */
312
+ template <typename T>
313
+ using _MaterializedIListRefElem = std::conditional_t<
314
+ std::is_reference_v<T>,
315
+ typename std::reference_wrapper<std::remove_reference_t<T>>,
316
+ T>;
317
+
318
+ template <typename T>
319
+ using MaterializedIListRefElem = _MaterializedIListRefElem<IListRefConstRef<T>>;
320
+
321
+ template <typename T>
322
+ using MaterializedIListRef = std::vector<MaterializedIListRefElem<T>>;
323
+
324
+ } // namespace detail
325
+
326
+ /*
327
+ * Iterator for `IListRef<T>`.
328
+ *
329
+ * What is it?
330
+ * ===========
331
+ * Currently, a `std::bidirectional_iterator` that wraps the iterator
332
+ * types defined for each of the `IListRefTag`.
333
+ *
334
+ * One should be able to use it, as if it were the unwrapped
335
+ * iterators themselves.
336
+
337
+ * What does it do?
338
+ * ================
339
+ * Similarly to `IListRef<T>`, this is a wrapper class. Specifically, it
340
+ * wraps each container's `const_iterator` type alias. So, for example,
341
+ * given that the container for `IListRefTag::Boxed` is `c10::List`, this
342
+ * iterator will wrap a `c10::List::const_iterator`.
343
+ *
344
+ * [Note: MSVC Iterator Debug]
345
+ * ===========================
346
+ * MSVC `vector<T>::iterator` implementation (used in the boxed variant)
347
+ * makes it so this union's destructor, copy-constructor (assignment), and
348
+ * move-constructor (assignment) are implicitly deleted.
349
+ *
350
+ * Therefore, we need to explicitly define them as needed. Follows a list
351
+ * of places where these are needed and their reason:
352
+ *
353
+ * - `Payload` destructor:
354
+ * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is set to 2.
355
+ *
356
+ * - `IListRefIterator` destructor:
357
+ * same as above. However, we need to explicitly call the variant
358
+ * destructor explicitly.
359
+ *
360
+ * - `IListRefIterator` copy-constructor:
361
+ * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is different
362
+ * than 0.
363
+ */
364
+ template <typename T>
365
+ class IListRefIterator {
366
+ private:
367
+ #define DEFINE_FRIEND_CLASS(TAG, ...) \
368
+ friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \
369
+ friend class detail::IListRefTagImplBase< \
370
+ IListRefTag::TAG, \
371
+ T, \
372
+ typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>;
373
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
374
+ #undef DEFINE_FRIEND_CLASS
375
+
376
+ public:
377
+ // C++17 friendly std::iterator implementation
378
+ using iterator_category = std::bidirectional_iterator_tag;
379
+ using value_type = T;
380
+ using difference_type = std::ptrdiff_t;
381
+ using pointer = T*;
382
+ using reference = T&;
383
+
384
+ using unboxed_iterator_type = typename detail::
385
+ IListRefTagImpl<IListRefTag::Unboxed, T>::list_type::const_iterator;
386
+ using boxed_iterator_type = typename detail::
387
+ IListRefTagImpl<IListRefTag::Boxed, T>::list_type::const_iterator;
388
+ using materialized_iterator_type =
389
+ typename detail::MaterializedIListRef<T>::const_iterator;
390
+
391
+ IListRefIterator() : tag_(IListRefTag::None) {}
392
+
393
+ #if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL != 0
394
+ // See [Note: MSVC Iterator Debug]
395
+ IListRefIterator(const IListRefIterator& iterator)
396
+ : tag_(iterator.tag_) {
397
+ switch (tag_) {
398
+ case IListRefTag::Boxed:
399
+ payload_.boxed_iterator = iterator.payload_.boxed_iterator;
400
+ break;
401
+ case IListRefTag::Unboxed:
402
+ payload_.unboxed_iterator = iterator.payload_.unboxed_iterator;
403
+ break;
404
+ case IListRefTag::Materialized:
405
+ payload_.materialized_iterator = iterator.payload_.materialized_iterator;
406
+ break;
407
+ default:
408
+ TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
409
+ }
410
+ }
411
+ #endif
412
+
413
+ #if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL == 2
414
+ // See [Note: MSVC Iterator Debug]
415
+ ~IListRefIterator() noexcept(false) {
416
+ switch (tag_) {
417
+ case IListRefTag::Boxed:
418
+ payload_.boxed_iterator.~boxed_iterator_type();
419
+ break;
420
+ case IListRefTag::Unboxed:
421
+ payload_.unboxed_iterator.~unboxed_iterator_type();
422
+ break;
423
+ case IListRefTag::Materialized:
424
+ payload_.materialized_iterator.~materialized_iterator_type();
425
+ break;
426
+ default:
427
+ TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
428
+ }
429
+ }
430
+ #endif
431
+
432
+ IListRefIterator(boxed_iterator_type boxed) : tag_(IListRefTag::Boxed) {
433
+ payload_.boxed_iterator = boxed;
434
+ }
435
+
436
+ IListRefIterator(unboxed_iterator_type unboxed) : tag_(IListRefTag::Unboxed) {
437
+ payload_.unboxed_iterator = unboxed;
438
+ }
439
+
440
+ IListRefIterator(materialized_iterator_type materialized) : tag_(IListRefTag::Materialized) {
441
+ payload_.materialized_iterator = materialized;
442
+ }
443
+
444
+ detail::IListRefConstRef<T> operator*() const {
445
+ TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::iterator_get(this_); });
446
+ }
447
+
448
+ IListRefIterator& operator++() {
449
+ TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
450
+ return *this;
451
+ }
452
+
453
+ IListRefIterator operator++(int) {
454
+ auto old = *this;
455
+ TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
456
+ return old;
457
+ }
458
+
459
+ IListRefIterator& operator--() {
460
+ TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
461
+ return *this;
462
+ }
463
+
464
+ IListRefIterator operator--(int) {
465
+ auto old = *this;
466
+ TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
467
+ return old;
468
+ }
469
+
470
+ bool operator==(const IListRefIterator& rhs) const {
471
+ if (tag_ != rhs.tag_) {
472
+ return false;
473
+ }
474
+ TORCH_ILISTREF_UNWRAP(tag_, {
475
+ auto& rhs_it = ImplT::unwrap(rhs);
476
+ return this_ == rhs_it;
477
+ });
478
+ }
479
+
480
+ bool operator!=(const IListRefIterator& rhs) const {
481
+ return !(*this == rhs);
482
+ }
483
+
484
+ private:
485
+ union Payload {
486
+ boxed_iterator_type boxed_iterator;
487
+ unboxed_iterator_type unboxed_iterator;
488
+ materialized_iterator_type materialized_iterator;
489
+ void* _init_ptr;
490
+ Payload() : _init_ptr(nullptr) {}
491
+ #if defined(_MSC_VER)
492
+ // See [Note: MSVC Iterator Debug]
493
+ ~Payload() {}
494
+ #endif
495
+ };
496
+
497
+ Payload payload_;
498
+ IListRefTag tag_;
499
+ };
500
+
501
+ /*
502
+ * See [Note: IListRef]
503
+ */
504
+ template <typename T>
505
+ class IListRef {
506
+ private:
507
+ #define DEFINE_FRIEND_CLASS(TAG, ...) \
508
+ friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \
509
+ friend class detail::IListRefTagImplBase< \
510
+ IListRefTag::TAG, \
511
+ T, \
512
+ typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>;
513
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
514
+ #undef DEFINE_FRIEND_CLASS
515
+
516
+ public:
517
+ using unboxed_type =
518
+ typename detail::IListRefTagImpl<IListRefTag::Unboxed, T>::list_type;
519
+ using boxed_type =
520
+ typename detail::IListRefTagImpl<IListRefTag::Boxed, T>::list_type;
521
+ using materialized_type =
522
+ typename detail::MaterializedIListRef<T>;
523
+
524
+ using iterator = IListRefIterator<T>;
525
+ using const_iterator = IListRefIterator<T>;
526
+ using reverse_iterator = std::reverse_iterator<iterator>;
527
+ using value_type = typename iterator::value_type;
528
+
529
+ IListRef() : tag_(IListRefTag::None) {}
530
+
531
+ IListRef(const boxed_type& boxed) : tag_(IListRefTag::Boxed) {
532
+ payload_.boxed = &boxed;
533
+ }
534
+
535
+ IListRef(const unboxed_type& unboxed) : tag_(IListRefTag::Unboxed) {
536
+ payload_.unboxed = unboxed;
537
+ }
538
+
539
+ IListRef(const std::initializer_list<T>& list) : tag_(IListRefTag::Unboxed) {
540
+ payload_.unboxed = at::ArrayRef<T>(list);
541
+ }
542
+
543
+ template <
544
+ typename... UnboxedConstructorArgs,
545
+ typename = std::enable_if_t<
546
+ std::is_constructible_v<unboxed_type, UnboxedConstructorArgs...>>>
547
+ IListRef(UnboxedConstructorArgs&&... args) : tag_(IListRefTag::Unboxed) {
548
+ payload_.unboxed = unboxed_type(std::forward<UnboxedConstructorArgs>(args)...);
549
+ }
550
+
551
+ IListRef(const materialized_type& materialized) : tag_(IListRefTag::Materialized) {
552
+ payload_.materialized = &materialized;
553
+ }
554
+
555
+ size_t size() const {
556
+ TORCH_ILISTREF_UNWRAP(tag_, { return this_.size(); });
557
+ }
558
+
559
+ bool empty() const {
560
+ return size() == 0;
561
+ }
562
+
563
+ iterator begin() const {
564
+ TORCH_ILISTREF_UNWRAP(tag_, { return this_.begin(); });
565
+ }
566
+
567
+ iterator end() const {
568
+ TORCH_ILISTREF_UNWRAP(tag_, { return this_.end(); });
569
+ }
570
+
571
+ detail::IListRefConstRef<T> front() const {
572
+ TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::front(this_); });
573
+ }
574
+
575
+ /*
576
+ * Materializes the `IListRef` into a `std::vector`.
577
+ *
578
+ * This should be used when one wishes to either:
579
+ *
580
+ * - iterate over the list more than once: each `IListRefIterator`
581
+ * member function call has to go through a switch, introducing
582
+ * non-negligible overhead
583
+ *
584
+ * - randomly access an arbitrary element using `operator[]`:
585
+ * same reason as above
586
+ */
587
+ detail::MaterializedIListRef<T> materialize() const {
588
+ if (isMaterialized()) {
589
+ return toMaterialized();
590
+ }
591
+
592
+ detail::MaterializedIListRef<T> materialized;
593
+ materialized.reserve(size());
594
+ for (const auto& t : *this) {
595
+ materialized.emplace_back(t);
596
+ }
597
+ return materialized;
598
+ }
599
+
600
+ #define DEFINE_CHECK(TAG, ...) \
601
+ bool is##TAG() const { \
602
+ return tag_ == IListRefTag::TAG; \
603
+ }
604
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_CHECK)
605
+ #undef DEFINE_CHECK
606
+
607
+ bool isNone() const {
608
+ return tag_ == IListRefTag::None;
609
+ }
610
+
611
+ #define DEFINE_CASTING(TAG, ...) \
612
+ const typename detail::IListRefTagImpl<IListRefTag::TAG, T>::list_type& \
613
+ to##TAG() const { \
614
+ TORCH_INTERNAL_ASSERT(is##TAG()); \
615
+ return detail::IListRefTagImpl<IListRefTag::TAG, T>::unwrap(*this); \
616
+ }
617
+ TORCH_ILISTREF_FORALL_TAGS(DEFINE_CASTING)
618
+ #undef DEFINE_CASTING
619
+
620
+ private:
621
+ union Payload {
622
+ const boxed_type* boxed;
623
+ unboxed_type unboxed;
624
+ const materialized_type* materialized;
625
+ Payload() : boxed(nullptr) {}
626
+ };
627
+
628
+ Payload payload_;
629
+ IListRefTag tag_;
630
+ };
631
+
632
+ } // namespace c10
633
+
634
+ #include <ATen/core/IListRef_inl.h>
635
+
636
+ #else
637
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
638
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef_inl.h ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/List.h>
5
+ #include <ATen/core/Tensor.h>
6
+
7
+ namespace at {
8
+ class Tensor;
9
+ class OptionalTensorRef;
10
+ }
11
+
12
+
13
+ namespace c10::detail {
14
+
15
+ /*
16
+ * Specializations of `IListRefTagImplBase` that implement the default
17
+ * implementation for `IListRefTag::Unboxed`.
18
+ */
19
+ template <typename T, typename ListElemT>
20
+ class IListRefTagImplBase<IListRefTag::Unboxed, T, ListElemT> {
21
+ public:
22
+ using elem_type = ListElemT;
23
+ using list_type = ArrayRef<elem_type>;
24
+
25
+ /*
26
+ * These `unwrap` static methods unwraps the inner containers out
27
+ * of `IListRef<T>` (and `IListRefIterator<T>`). They are required when
28
+ * the macro `TORCH_ILISTREF_UNWRAP` is called.
29
+ */
30
+ static const list_type& unwrap(const IListRef<T>& ilist) {
31
+ return ilist.payload_.unboxed;
32
+ }
33
+
34
+ static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
35
+ return it.payload_.unboxed_iterator;
36
+ }
37
+
38
+ static const typename list_type::const_iterator& unwrap(
39
+ const IListRefIterator<T>& it) {
40
+ return it.payload_.unboxed_iterator;
41
+ }
42
+
43
+ /*
44
+ * We have these function (besides the `unwrap`s above) because the
45
+ * implementation for both `IListRef::operator[]` and `IListRefIterator::operator*`
46
+ * weren't syntactically equal for the existing tags at the time
47
+ * (`Unboxed` and `Boxed`).
48
+ */
49
+ static IListRefConstRef<T> front(const list_type& lst) {
50
+ return lst.front();
51
+ }
52
+
53
+ static IListRefConstRef<T> iterator_get(
54
+ const typename list_type::const_iterator& it) {
55
+ return *it;
56
+ }
57
+ };
58
+
59
+ /*
60
+ * Specializations of `IListRefTagImplBase` that implement the default
61
+ * implementation for `IListRefTag::Boxed`.
62
+ */
63
+ template <typename T, typename ListElemT>
64
+ class IListRefTagImplBase<IListRefTag::Boxed, T, ListElemT> {
65
+ public:
66
+ using elem_type = ListElemT;
67
+ using list_type = List<elem_type>;
68
+
69
+ static const list_type& unwrap(const IListRef<T>& ilist) {
70
+ return *ilist.payload_.boxed;
71
+ }
72
+
73
+ static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
74
+ return it.payload_.boxed_iterator;
75
+ }
76
+
77
+ static const typename list_type::const_iterator& unwrap(
78
+ const IListRefIterator<T>& it) {
79
+ return it.payload_.boxed_iterator;
80
+ }
81
+
82
+ static IListRefConstRef<T> front(const list_type& lst) {
83
+ return lst[0];
84
+ }
85
+
86
+ static IListRefConstRef<T> iterator_get(
87
+ const typename list_type::const_iterator& it) {
88
+ return (*it).get().toTensor();
89
+ }
90
+ };
91
+
92
+ /*
93
+ * Specializations of `IListRefTagImplBase` that implement the default
94
+ * implementation for `IListRefTag::Materialized`.
95
+ */
96
+ template <typename T>
97
+ class IListRefTagImplBase<IListRefTag::Materialized, T, MaterializedIListRefElem<T>> {
98
+ public:
99
+ using elem_type = MaterializedIListRefElem<T>;
100
+ using list_type = MaterializedIListRef<T>;
101
+
102
+ static const list_type& unwrap(const IListRef<T>& ilist) {
103
+ return *ilist.payload_.materialized;
104
+ }
105
+
106
+ static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
107
+ return it.payload_.materialized_iterator;
108
+ }
109
+
110
+ static const typename list_type::const_iterator& unwrap(
111
+ const IListRefIterator<T>& it) {
112
+ return it.payload_.materialized_iterator;
113
+ }
114
+
115
+ static IListRefConstRef<T> front(const list_type& lst) {
116
+ return lst[0];
117
+ }
118
+
119
+ static IListRefConstRef<T> iterator_get(
120
+ const typename list_type::const_iterator& it) {
121
+ return *it;
122
+ }
123
+ };
124
+
125
+ /*
126
+ * [Note: ITensorListRef]
127
+ * Specializations necessary for `IListRef<at::Tensor>` type.
128
+ *
129
+ * Since the default implementations are usually done with supporting
130
+ * `Tensor` in mind, we only have to inherit from the base implementations.
131
+ */
132
+ template <>
133
+ class IListRefTagImpl<IListRefTag::Unboxed, at::Tensor>
134
+ : public IListRefTagImplBase<IListRefTag::Unboxed, at::Tensor> {};
135
+
136
+ template <>
137
+ class IListRefTagImpl<IListRefTag::Boxed, at::Tensor>
138
+ : public IListRefTagImplBase<IListRefTag::Boxed, at::Tensor> {};
139
+
140
+ template <>
141
+ class IListRefTagImpl<IListRefTag::Materialized, at::Tensor>
142
+ : public IListRefTagImplBase<
143
+ IListRefTag::Materialized,
144
+ at::Tensor,
145
+ MaterializedIListRefElem<at::Tensor>> {};
146
+
147
+ /*
148
+ * [Note: IOptTensorListRef]
149
+ * Specializations necessary for `IListRef<at::OptionalTensorRef>` type.
150
+ *
151
+ * We can't get an `at::OptionalTensorRef` directly from an instance of
152
+ * `List<optional<Tensor>>` (the type that corresponds to the boxed world).
153
+ *
154
+ * So, the default implementation won't help us. Thus, we have to implement
155
+ * this method ourselves.
156
+ */
157
+ template <>
158
+ class IListRefTagImpl<IListRefTag::Unboxed, at::OptionalTensorRef>
159
+ : public IListRefTagImplBase<IListRefTag::Unboxed, at::OptionalTensorRef> {};
160
+
161
+ template <>
162
+ class IListRefTagImpl<IListRefTag::Boxed, at::OptionalTensorRef>
163
+ : public IListRefTagImplBase<IListRefTag::Boxed, at::OptionalTensorRef, std::optional<at::Tensor>> {
164
+
165
+ public:
166
+ /*
167
+ * Given an instance of the types corresponding to the `Boxed` tag, we override
168
+ * the default implementation, so that we can return a `at::OptionalTensorRef`.
169
+ */
170
+ static IListRefConstRef<at::OptionalTensorRef> iterator_get(
171
+ const typename list_type::const_iterator& it) {
172
+ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdangling-reference")
173
+ const auto& ivalue = (*it).get();
174
+ C10_DIAGNOSTIC_POP()
175
+ if (!ivalue.isNone()) {
176
+ const auto& tensor = ivalue.toTensor();
177
+ return (tensor.defined()) ? tensor : at::OptionalTensorRef{};
178
+ }
179
+ return {};
180
+ }
181
+ };
182
+
183
+ template <>
184
+ class IListRefTagImpl<IListRefTag::Materialized, at::OptionalTensorRef>
185
+ : public IListRefTagImplBase<
186
+ IListRefTag::Materialized,
187
+ at::OptionalTensorRef,
188
+ MaterializedIListRefElem<at::OptionalTensorRef>> {};
189
+
190
+ } // namespace c10::detail
191
+
192
+
193
+ namespace at {
194
+
195
+ // [Note: ITensorListRef]
196
+ using ITensorListRef = c10::IListRef<at::Tensor>;
197
+ using ITensorListRefIterator = c10::IListRefIterator<at::Tensor>;
198
+ using MaterializedITensorListRef = c10::detail::MaterializedIListRef<at::Tensor>;
199
+ // [Note: IOptTensorListRef]
200
+ using IOptTensorListRef = c10::IListRef<at::OptionalTensorRef>;
201
+ using IOptTensorListRefIterator = c10::IListRefIterator<at::OptionalTensorRef>;
202
+ using MaterializedIOptTensorListRef = c10::detail::MaterializedIListRef<at::OptionalTensorRef>;
203
+
204
+ } // namespace at
205
+
206
+ #else
207
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
208
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // The legacy mechanism for dispatching operators in ATen is a Type
5
+ // object, which is essentially a giant virtual dispatch table
6
+ // for every operation we support dynamically dispatching over.
7
+ //
8
+ // This has been deprecated in favor of ATenDispatch, and in the future,
9
+ // c10 dispatcher.
10
+ // TODO: Clean up what remains here
11
+
12
+ #include <c10/core/impl/LocalDispatchKeySet.h>
13
+
14
+ namespace at {
15
+
16
+ // A RAII, thread local (!) guard that will disable dispatch to variable
17
+ // handler.
18
+ //
19
+ // NOTE [ Treating Variables as non-Variables in type dispatch ]
20
+ //
21
+ // What exactly does AutoDispatchBelowAutograd do? The short answer is, it causes
22
+ // dispatches on ATen functions to go to the non-variable implementation,
23
+ // bypassing autograd handling (and also profiling and tracing).
24
+ //
25
+ // To understand why this guard exists, it's helpful to understand the history
26
+ // behind how Variable was implemented. Previously, Variables were implemented
27
+ // as a wrapper on Tensors; so the act of processing a Variable involved
28
+ // unwrapping the underlying Tensor, and then calling the underlying base
29
+ // operation on /that/ operation
30
+ //
31
+ // However, after the Variable/Tensor merge, there is no concept of unwrapping
32
+ // a tensor anymore. If you just call the operation on the same variable
33
+ // again inside your VariableType handler, you'll dispatch back to
34
+ // VariableType, which is not what we want.
35
+ //
36
+ // The solution to the above problem is to add `at::AutoDispatchBelowAutograd`, which
37
+ // when enabled will cause `legacyTensorType()` and `getType()` to always return
38
+ // non-Variable type, even if the tensor being called on is a variable.
39
+
40
+ /* Note [AutoDispatchBelowAutograd]
41
+ * AutoDispatchBelowAutograd is **INTERNAL ONLY** that it should be used
42
+ * for kernel implementations and customized C++ kernels.
43
+ * If you are looking for a guard to run workload in inference mode, please use
44
+ * c10::InferenceMode RAII which is user facing API.
45
+ * In the past AutoDispatchBelowAutograd(or its old version AutoNonVariableTypeMode)
46
+ * was used in the user code for inference-only workload, this was under risk of
47
+ * producing wrong results silently in some edge cases. For example:
48
+ * ```
49
+ * torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
50
+ * torch::Tensor out = s * s;
51
+ * {
52
+ * at::AutoDispatchBelowAutograd guard;
53
+ * s.add_(1); // Skips version bump on `s`.
54
+ * }
55
+ * // WRONG GRADIENT! s.grad() are now computed using `s` value after the
56
+ * // inplace update.
57
+ * out.backward(torch::ones_like(out));
58
+ * ```
59
+ * Users should use `c10::InferenceMode` here so that it'll properly throw an
60
+ * error saying "one of the variables needed for gradient computation has be modified."
61
+ */
62
+ struct TORCH_API AutoDispatchBelowAutograd {
63
+ AutoDispatchBelowAutograd() :
64
+ autograd_guard_(c10::autograd_dispatch_keyset) {
65
+ }
66
+
67
+ // disable all autograd dispatch keys
68
+ c10::impl::ExcludeDispatchKeyGuard autograd_guard_;
69
+ };
70
+
71
+ // TODO: AutoNonVariableTypeMode should be removed in release 1.10.
72
+ struct TORCH_API AutoNonVariableTypeMode {
73
+ AutoNonVariableTypeMode(bool enabled = true) :
74
+ autograd_guard_(c10::autograd_dispatch_keyset) {
75
+ TORCH_WARN_ONCE("AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. "
76
+ "For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, "
77
+ "If you are looking for a user facing API to enable running your inference-only "
78
+ "workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code "
79
+ "is under risk of producing silent wrong result in some edge cases. "
80
+ "See Note [AutoDispatchBelowAutograd] for more details.");
81
+ TORCH_INTERNAL_ASSERT(enabled);
82
+ }
83
+
84
+ // disable all autograd dispatch keys
85
+ c10::impl::ExcludeDispatchKeyGuard autograd_guard_;
86
+ };
87
+
88
+ struct TORCH_API AutoDispatchSkipFunctionalize {
89
+ AutoDispatchSkipFunctionalize() :
90
+ dispatch_key_guard_(c10::DispatchKeySet(c10::DispatchKey::Functionalize)) {
91
+ }
92
+ c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_;
93
+ };
94
+
95
+ /* Note [AutoDispatchBelowADInplaceOrView]
96
+ * AutoDispatchBelowADInplaceOrView is equivalent to AutoNonVariableTypeMode
97
+ * before we split inplace & view ops out of VariableType kernel.
98
+ * Note this guard is used in VariableType kernels for functional ops
99
+ * as well as ADInplaceOrView kernels for inplace/view ops to enforce the
100
+ * Invariant:
101
+ * Once you are in VariableType/ADInplaceOrView kernel for an op,
102
+ * you never go back to a kernel on same dispatch key until
103
+ * you finish the current op.
104
+ */
105
+ struct TORCH_API AutoDispatchBelowADInplaceOrView {
106
+ AutoDispatchBelowADInplaceOrView() :
107
+ dispatch_key_guard_(c10::autograd_dispatch_keyset_with_ADInplaceOrView) {
108
+ }
109
+ // disable Autograd & ADInplaceOrView dispatch keys
110
+ c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_;
111
+ };
112
+ } // namespace at
113
+
114
+ #else
115
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
116
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/List.h ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/ivalue_to.h>
5
+ #include <ATen/core/jit_type_base.h>
6
+ #include <c10/macros/Macros.h>
7
+ #include <c10/macros/Export.h>
8
+ #include <c10/util/TypeTraits.h>
9
+ #include <c10/util/TypeList.h>
10
+ #include <c10/util/intrusive_ptr.h>
11
+ #include <c10/util/ArrayRef.h>
12
+ #include <optional>
13
+ #include <vector>
14
+
15
+ namespace at {
16
+ class Tensor;
17
+ }
18
+ namespace c10 {
19
+ struct IValue;
20
+ template<class T> class List;
21
+ struct Type;
22
+
23
+ namespace detail {
24
+
25
+ struct ListImpl final : public c10::intrusive_ptr_target {
26
+ using list_type = std::vector<IValue>;
27
+
28
+ explicit TORCH_API ListImpl(list_type list_, TypePtr elementType_);
29
+
30
+ list_type list;
31
+
32
+ TypePtr elementType;
33
+
34
+ intrusive_ptr<ListImpl> copy() const {
35
+ return make_intrusive<ListImpl>(list, elementType);
36
+ }
37
+ friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs);
38
+ };
39
+ }
40
+
41
+ namespace impl {
42
+
43
+ template<class T, class Iterator> class ListIterator;
44
+
45
+ template<class T, class Iterator> class ListElementReference;
46
+
47
+ template<class T, class Iterator>
48
+ void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept;
49
+
50
+ template<class T, class Iterator>
51
+ bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs);
52
+
53
+ template<class T, class Iterator>
54
+ bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs);
55
+
56
+ template<class T>
57
+ struct ListElementConstReferenceTraits {
58
+ // In the general case, we use IValue::to().
59
+ using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return<T>::type;
60
+ };
61
+
62
+ // There is no to() overload for std::optional<std::string>.
63
+ template<>
64
+ struct ListElementConstReferenceTraits<std::optional<std::string>> {
65
+ using const_reference = std::optional<std::reference_wrapper<const std::string>>;
66
+ };
67
+
68
+ template<class T, class Iterator>
69
+ class ListElementReference final {
70
+ public:
71
+ operator std::conditional_t<
72
+ std::is_reference_v<typename c10::detail::
73
+ ivalue_to_const_ref_overload_return<T>::type>,
74
+ const T&,
75
+ T>() const;
76
+
77
+ ListElementReference& operator=(T&& new_value) &&;
78
+
79
+ ListElementReference& operator=(const T& new_value) &&;
80
+
81
+ // assigning another ref to this assigns the underlying value
82
+ ListElementReference& operator=(ListElementReference&& rhs) && noexcept;
83
+
84
+ const IValue& get() const& {
85
+ return *iterator_;
86
+ }
87
+
88
+ friend void swap<T, Iterator>(ListElementReference&& lhs, ListElementReference&& rhs) noexcept;
89
+
90
+ ListElementReference(const ListElementReference&) = delete;
91
+ ListElementReference& operator=(const ListElementReference&) = delete;
92
+ ~ListElementReference() = default;
93
+
94
+ private:
95
+ ListElementReference(Iterator iter)
96
+ : iterator_(iter) {}
97
+
98
+ // allow moving, but only our friends (i.e. the List class) can move us
99
+ ListElementReference(ListElementReference&&) noexcept = default;
100
+ ListElementReference& operator=(ListElementReference&& rhs) & noexcept {
101
+ iterator_ = std::move(rhs.iterator_);
102
+ return *this;
103
+ }
104
+
105
+ friend class List<T>;
106
+ friend class ListIterator<T, Iterator>;
107
+
108
+ Iterator iterator_;
109
+ };
110
+
111
+ // this wraps vector::iterator to make sure user code can't rely
112
+ // on it being the type of the underlying vector.
113
+ template <class T, class Iterator>
114
+ class ListIterator final {
115
+ public:
116
+ // C++17 friendly std::iterator implementation
117
+ using iterator_category = std::random_access_iterator_tag;
118
+ using value_type = T;
119
+ using difference_type = std::ptrdiff_t;
120
+ using pointer = T*;
121
+ using reference = ListElementReference<T, Iterator>;
122
+
123
+ explicit ListIterator() = default;
124
+ ~ListIterator() = default;
125
+
126
+ ListIterator(const ListIterator&) = default;
127
+ ListIterator(ListIterator&&) noexcept = default;
128
+ ListIterator& operator=(const ListIterator&) = default;
129
+ ListIterator& operator=(ListIterator&&) noexcept = default;
130
+
131
+ ListIterator& operator++() {
132
+ ++iterator_;
133
+ return *this;
134
+ }
135
+
136
+ ListIterator operator++(int) {
137
+ ListIterator copy(*this);
138
+ ++*this;
139
+ return copy;
140
+ }
141
+
142
+ ListIterator& operator--() {
143
+ --iterator_;
144
+ return *this;
145
+ }
146
+
147
+ ListIterator operator--(int) {
148
+ ListIterator copy(*this);
149
+ --*this;
150
+ return copy;
151
+ }
152
+
153
+ ListIterator& operator+=(typename List<T>::size_type offset) {
154
+ iterator_ += offset;
155
+ return *this;
156
+ }
157
+
158
+ ListIterator& operator-=(typename List<T>::size_type offset) {
159
+ iterator_ -= offset;
160
+ return *this;
161
+ }
162
+
163
+ ListIterator operator+(typename List<T>::size_type offset) const {
164
+ return ListIterator{iterator_ + offset};
165
+ }
166
+
167
+ ListIterator operator-(typename List<T>::size_type offset) const {
168
+ return ListIterator{iterator_ - offset};
169
+ }
170
+
171
+ friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) {
172
+ return lhs.iterator_ - rhs.iterator_;
173
+ }
174
+
175
+ ListElementReference<T, Iterator> operator*() const {
176
+ return {iterator_};
177
+ }
178
+
179
+ ListElementReference<T, Iterator> operator[](typename List<T>::size_type offset) const {
180
+ return {iterator_ + offset};
181
+ }
182
+
183
+ private:
184
+ explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {}
185
+
186
+ Iterator iterator_;
187
+
188
+ friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) {
189
+ return lhs.iterator_ == rhs.iterator_;
190
+ }
191
+
192
+ friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) {
193
+ return !(lhs == rhs);
194
+ }
195
+
196
+ friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) {
197
+ return lhs.iterator_ < rhs.iterator_;
198
+ }
199
+
200
+ friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) {
201
+ return lhs.iterator_ <= rhs.iterator_;
202
+ }
203
+
204
+ friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) {
205
+ return lhs.iterator_ > rhs.iterator_;
206
+ }
207
+
208
+ friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) {
209
+ return lhs.iterator_ >= rhs.iterator_;
210
+ }
211
+
212
+ friend class ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
213
+ friend class List<T>;
214
+ };
215
+
216
+ template<class T> List<T> toTypedList(List<IValue> list);
217
+ template<class T> List<IValue> toList(List<T>&& list);
218
+ template<class T> List<IValue> toList(const List<T>& list);
219
+ const IValue* ptr_to_first_element(const List<IValue>& list);
220
+ }
221
+
222
+ /**
223
+ * An object of this class stores a list of values of type T.
224
+ *
225
+ * This is a pointer type. After a copy, both Lists
226
+ * will share the same storage:
227
+ *
228
+ * > List<int> a;
229
+ * > List<int> b = a;
230
+ * > b.push_back("three");
231
+ * > ASSERT("three" == a.get(0));
232
+ *
233
+ * We use this class in the PyTorch kernel API instead of
234
+ * std::vector<T>, because that allows us to do optimizations
235
+ * and switch out the underlying list implementation without
236
+ * breaking backwards compatibility for the kernel API.
237
+ */
238
+ template<class T>
239
+ // NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
240
+ class List final {
241
+ private:
242
+ // This is an intrusive_ptr because List is a pointer type.
243
+ // Invariant: This will never be a nullptr, there will always be a valid
244
+ // ListImpl.
245
+ c10::intrusive_ptr<c10::detail::ListImpl> impl_;
246
+
247
+ using internal_reference_type = impl::ListElementReference<T, typename c10::detail::ListImpl::list_type::iterator>;
248
+ using internal_const_reference_type = typename impl::ListElementConstReferenceTraits<T>::const_reference;
249
+
250
+ public:
251
+ using value_type = T;
252
+ using size_type = typename c10::detail::ListImpl::list_type::size_type;
253
+ using iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
254
+ using const_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
255
+ using reverse_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::reverse_iterator>;
256
+
257
+ /**
258
+ * Constructs an empty list.
259
+ */
260
+ explicit List();
261
+
262
+ /**
263
+ * Constructs a list with some initial values.
264
+ * Example:
265
+ * List<int> a({2, 3, 4});
266
+ */
267
+ List(std::initializer_list<T> initial_values);
268
+ explicit List(ArrayRef<T> initial_values);
269
+
270
+ /**
271
+ * Create a generic list with runtime type information.
272
+ * This only works for c10::impl::GenericList and is not part of the public API
273
+ * but only supposed to be used internally by PyTorch.
274
+ */
275
+ explicit List(TypePtr elementType);
276
+
277
+ List(const List&) = default;
278
+ List& operator=(const List&) = default;
279
+ ~List() = default;
280
+
281
+ /**
282
+ * Create a new List pointing to a deep copy of the same data.
283
+ * The List returned is a new list with separate storage.
284
+ * Changes in it are not reflected in the original list or vice versa.
285
+ */
286
+ List copy() const;
287
+
288
+ /**
289
+ * Returns the element at specified location pos, with bounds checking.
290
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
291
+ */
292
+ internal_const_reference_type get(size_type pos) const;
293
+
294
+ /**
295
+ * Moves out the element at the specified location pos and returns it, with bounds checking.
296
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
297
+ * The list contains an invalid element at position pos afterwards. Any operations
298
+ * on it before re-setting it are invalid.
299
+ */
300
+ value_type extract(size_type pos) const;
301
+
302
+ /**
303
+ * Returns a reference to the element at specified location pos, with bounds checking.
304
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
305
+ *
306
+ * You cannot store the reference, but you can read it and assign new values to it:
307
+ *
308
+ * List<int64_t> list = ...;
309
+ * list[2] = 5;
310
+ * int64_t v = list[1];
311
+ */
312
+ internal_const_reference_type operator[](size_type pos) const;
313
+
314
+ internal_reference_type operator[](size_type pos);
315
+
316
+ /**
317
+ * Assigns a new value to the element at location pos.
318
+ */
319
+ void set(size_type pos, const value_type& value) const;
320
+
321
+ /**
322
+ * Assigns a new value to the element at location pos.
323
+ */
324
+ void set(size_type pos, value_type&& value) const;
325
+
326
+ /**
327
+ * Returns an iterator to the first element of the container.
328
+ * If the container is empty, the returned iterator will be equal to end().
329
+ */
330
+ iterator begin() const;
331
+
332
+ /**
333
+ * Returns an iterator to the element following the last element of the container.
334
+ * This element acts as a placeholder; attempting to access it results in undefined behavior.
335
+ */
336
+ iterator end() const;
337
+
338
+ /**
339
+ * Checks if the container has no elements.
340
+ */
341
+ bool empty() const;
342
+
343
+ /**
344
+ * Returns the number of elements in the container
345
+ */
346
+ size_type size() const;
347
+
348
+ /**
349
+ * Increase the capacity of the vector to a value that's greater or equal to new_cap.
350
+ */
351
+ void reserve(size_type new_cap) const;
352
+
353
+ /**
354
+ * Erases all elements from the container. After this call, size() returns zero.
355
+ * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated.
356
+ */
357
+ void clear() const;
358
+
359
+ /**
360
+ * Inserts value before pos.
361
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
362
+ */
363
+ iterator insert(iterator pos, const T& value) const;
364
+
365
+ /**
366
+ * Inserts value before pos.
367
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
368
+ */
369
+ iterator insert(iterator pos, T&& value) const;
370
+
371
+ /**
372
+ * Inserts a new element into the container directly before pos.
373
+ * The new element is constructed with the given arguments.
374
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
375
+ */
376
+ template<class... Args>
377
+ iterator emplace(iterator pos, Args&&... value) const;
378
+
379
+ /**
380
+ * Appends the given element value to the end of the container.
381
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
382
+ */
383
+ void push_back(const T& value) const;
384
+
385
+ /**
386
+ * Appends the given element value to the end of the container.
387
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
388
+ */
389
+ void push_back(T&& value) const;
390
+
391
+ /**
392
+ * Appends the given list to the end of the container. Uses at most one memory allocation.
393
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
394
+ */
395
+ void append(List<T> lst) const;
396
+
397
+ /**
398
+ * Appends the given element value to the end of the container.
399
+ * The new element is constructed with the given arguments.
400
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
401
+ */
402
+ template<class... Args>
403
+ void emplace_back(Args&&... args) const;
404
+
405
+ /**
406
+ * Removes the element at pos.
407
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
408
+ */
409
+ iterator erase(iterator pos) const;
410
+
411
+ /**
412
+ * Removes the elements in the range [first, last).
413
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
414
+ */
415
+ iterator erase(iterator first, iterator last) const;
416
+
417
+ /**
418
+ * Removes the last element of the container.
419
+ * Calling pop_back on an empty container is undefined.
420
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
421
+ */
422
+ void pop_back() const;
423
+
424
+ /**
425
+ * Resizes the container to contain count elements.
426
+ * If the current size is less than count, additional default-inserted elements are appended.
427
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
428
+ */
429
+ void resize(size_type count) const;
430
+
431
+ /**
432
+ * Resizes the container to contain count elements.
433
+ * If the current size is less than count, additional copies of value are appended.
434
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
435
+ */
436
+ void resize(size_type count, const T& value) const;
437
+
438
+ /**
439
+ * Value equality comparison. This function implements Python-like semantics for
440
+ * equality: two lists with the same identity (e.g. same pointer) trivially
441
+ * compare equal, otherwise each element is compared for equality.
442
+ */
443
+ template <class T_>
444
+ friend bool operator==(const List<T_>& lhs, const List<T_>& rhs);
445
+
446
+ template <class T_>
447
+ friend bool operator!=(const List<T_>& lhs, const List<T_>& rhs);
448
+
449
+ /**
450
+ * Identity comparison. Returns true if and only if `rhs` represents the same
451
+ * List object as `this`.
452
+ */
453
+ bool is(const List<T>& rhs) const;
454
+
455
+ std::vector<T> vec() const;
456
+
457
+ /**
458
+ * Returns the number of Lists currently pointing to this same list.
459
+ * If this is the only instance pointing to this list, returns 1.
460
+ */
461
+ // TODO Test use_count
462
+ size_t use_count() const;
463
+
464
+ TypePtr elementType() const;
465
+
466
+ // See [unsafe set type] for why this exists.
467
+ void unsafeSetElementType(TypePtr t);
468
+
469
+ private:
470
+ explicit List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements);
471
+ explicit List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements);
472
+ friend struct IValue;
473
+ template<class T_> friend List<T_> impl::toTypedList(List<IValue>);
474
+ template<class T_> friend List<IValue> impl::toList(List<T_>&&);
475
+ template<class T_> friend List<IValue> impl::toList(const List<T_>&);
476
+ friend const IValue* impl::ptr_to_first_element(const List<IValue>& list);
477
+ };
478
+
479
+ namespace impl {
480
+ // GenericList is how IValue stores lists. It is, however, not part of the
481
+ // public API. Kernels should use Lists with concrete types instead
482
+ // (maybe except for some internal prim ops).
483
+ using GenericList = List<IValue>;
484
+
485
+ }
486
+ }
487
+
488
+ namespace torch {
489
+ template<class T> using List = c10::List<T>;
490
+ }
491
+
492
+ #include <ATen/core/List_inl.h> // IWYU pragma: keep
493
+
494
+ #else
495
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
496
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/List_inl.h ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/jit_type_base.h>
5
+ #include <ATen/core/ivalue.h>
6
+
7
+ namespace c10 {
8
+
9
+ template<class T> decltype(auto) getTypePtr();
10
+ std::string toString(const Type& type);
11
+
12
+ template<class T>
13
+ List<T>::List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements)
14
+ : impl_(std::move(elements)) {}
15
+
16
+ template<class T>
17
+ List<T>::List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements)
18
+ : impl_(elements) {}
19
+
20
+ template<class T>
21
+ List<T>::List()
22
+ : List(make_intrusive<c10::detail::ListImpl>(
23
+ typename c10::detail::ListImpl::list_type(),
24
+ getTypePtr<T>())) {
25
+ static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType) instead.");
26
+ }
27
+
28
+ template<class T>
29
+ List<T>::List(ArrayRef<T> values)
30
+ : List(make_intrusive<c10::detail::ListImpl>(
31
+ typename c10::detail::ListImpl::list_type(),
32
+ getTypePtr<T>())) {
33
+ static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
34
+ impl_->list.reserve(values.size());
35
+ for (const T& element : values) {
36
+ impl_->list.push_back(element);
37
+ }
38
+ }
39
+
40
+ template<class T>
41
+ List<T>::List(std::initializer_list<T> initial_values)
42
+ : List(ArrayRef<T>(initial_values)) {
43
+ static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
44
+ }
45
+
46
+ template<class T>
47
+ List<T>::List(TypePtr elementType)
48
+ : List(make_intrusive<c10::detail::ListImpl>(
49
+ typename c10::detail::ListImpl::list_type(),
50
+ std::move(elementType))) {
51
+ static_assert(std::is_same_v<T, IValue> || std::is_same_v<T, c10::intrusive_ptr<ivalue::Future>>,
52
+ "This constructor is only valid for c10::impl::GenericList or List<Future>.");
53
+ }
54
+
55
+ namespace impl {
56
+ template<class T>
57
+ List<T> toTypedList(impl::GenericList list) {
58
+ // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
59
+ // because upcasting would allow people to add types into the new list that would break the old list.
60
+ // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
61
+ // allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
62
+ // without having to copy it. This is also used to provide backwards compatibility with some old models
63
+ // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
64
+ // as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
65
+ // have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
66
+ TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
67
+ || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr<T>()))
68
+ , "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(*getTypePtr<T>()), ">. Types mismatch.");
69
+ return List<T>(std::move(list.impl_));
70
+ }
71
+
72
+ template<class T>
73
+ impl::GenericList toList(List<T>&& list) {
74
+ return GenericList(std::move(list.impl_));
75
+ }
76
+ template<class T>
77
+ impl::GenericList toList(const List<T>& list) {
78
+ return GenericList(list.impl_);
79
+ }
80
+ }
81
+
82
+ template<class T>
83
+ List<T> List<T>::copy() const {
84
+ return List<T>(impl_->copy());
85
+ }
86
+
87
+ namespace detail {
88
+ template<class T>
89
+ T list_element_to(T element) {
90
+ return element;
91
+ }
92
+ template<class T>
93
+ T list_element_to(const IValue& element) {
94
+ return element.template to<T>();
95
+ }
96
+ template<class T>
97
+ T list_element_to(IValue&& element) {
98
+ return std::move(element).template to<T>();
99
+ }
100
+ template<class T>
101
+ struct ListElementFrom {
102
+ static IValue from(const T& element) {
103
+ return element;
104
+ }
105
+ static IValue from(T&& element) {
106
+ return std::move(element);
107
+ }
108
+ };
109
+ template<>
110
+ struct ListElementFrom<IValue> {
111
+ static const IValue& from(const IValue& element) {
112
+ return element;
113
+ }
114
+ static IValue&& from(IValue&& element) {
115
+ return std::move(element);
116
+ }
117
+ };
118
+ }
119
+
120
+ namespace impl {
121
+
122
+ template <class T, class Iterator>
123
+ ListElementReference<T, Iterator>::operator std::conditional_t<
124
+ std::is_reference_v<typename c10::detail::ivalue_to_const_ref_overload_return<
125
+ T>::type>,
126
+ const T&,
127
+ T>() const {
128
+ return iterator_->template to<T>();
129
+ }
130
+
131
+ template<class T, class Iterator>
132
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(T&& new_value) && {
133
+ *iterator_ = c10::detail::ListElementFrom<T>::from(std::move(new_value));
134
+ return *this;
135
+ }
136
+
137
+ template<class T, class Iterator>
138
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(const T& new_value) && {
139
+ *iterator_ = c10::detail::ListElementFrom<T>::from(new_value);
140
+ return *this;
141
+ }
142
+
143
+ template<class T, class Iterator>
144
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(ListElementReference<T, Iterator>&& rhs) && noexcept {
145
+ *iterator_ = *rhs.iterator_;
146
+ return *this;
147
+ }
148
+
149
+ template<class T, class Iterator>
150
+ void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept {
151
+ std::swap(*lhs.iterator_, *rhs.iterator_);
152
+ }
153
+
154
+ template<class T, class Iterator>
155
+ bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs) {
156
+ const T& lhs_tmp = lhs;
157
+ return lhs_tmp == rhs;
158
+ }
159
+
160
+ template<class T, class Iterator>
161
+ inline bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs) {
162
+ return rhs == lhs;
163
+ }
164
+
165
+ template<class T>
166
+ inline typename ListElementConstReferenceTraits<T>::const_reference
167
+ list_element_to_const_ref(const IValue& element) {
168
+ return element.template to<T>();
169
+ }
170
+
171
+ template<>
172
+ inline typename ListElementConstReferenceTraits<std::optional<std::string>>::const_reference
173
+ list_element_to_const_ref<std::optional<std::string>>(const IValue& element) {
174
+ return element.toOptionalStringRef();
175
+ }
176
+
177
+ } // namespace impl
178
+
179
+ template<class T>
180
+ void List<T>::set(size_type pos, const value_type& value) const {
181
+ impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(value);
182
+ }
183
+
184
+ template<class T>
185
+ void List<T>::set(size_type pos, value_type&& value) const {
186
+ impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(std::move(value));
187
+ }
188
+
189
+ template<class T>
190
+ typename List<T>::internal_const_reference_type List<T>::get(size_type pos) const {
191
+ return operator[](pos);
192
+ }
193
+
194
+ template<class T>
195
+ typename List<T>::internal_const_reference_type List<T>::operator[](size_type pos) const {
196
+ return c10::impl::list_element_to_const_ref<T>(impl_->list.at(pos));
197
+ }
198
+
199
+ template<class T>
200
+ typename List<T>::internal_reference_type List<T>::operator[](size_type pos) {
201
+ static_cast<void>(impl_->list.at(pos)); // Throw the exception if it is out of range.
202
+ return {impl_->list.begin() + static_cast<typename decltype(impl_->list)::difference_type>(pos)};
203
+ }
204
+
205
+ template<class T>
206
+ typename List<T>::value_type List<T>::extract(size_type pos) const {
207
+ auto& elem = impl_->list.at(pos);
208
+ auto result = c10::detail::list_element_to<T>(std::move(elem));
209
+ // Reset the list element to a T() instead of None to keep it correctly typed
210
+ elem = c10::detail::ListElementFrom<T>::from(T{});
211
+ return result;
212
+ }
213
+
214
+ template<class T>
215
+ typename List<T>::iterator List<T>::begin() const {
216
+ return iterator(impl_->list.begin());
217
+ }
218
+
219
+ template<class T>
220
+ typename List<T>::iterator List<T>::end() const {
221
+ return iterator(impl_->list.end());
222
+ }
223
+
224
+ template<class T>
225
+ bool List<T>::empty() const {
226
+ return impl_->list.empty();
227
+ }
228
+
229
+ template<class T>
230
+ typename List<T>::size_type List<T>::size() const {
231
+ return impl_->list.size();
232
+ }
233
+
234
+ template<class T>
235
+ void List<T>::reserve(size_type new_cap) const {
236
+ impl_->list.reserve(new_cap);
237
+ }
238
+
239
+ template<class T>
240
+ void List<T>::clear() const {
241
+ impl_->list.clear();
242
+ }
243
+
244
+ template<class T>
245
+ typename List<T>::iterator List<T>::insert(iterator pos, const T& value) const {
246
+ return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(value)) };
247
+ }
248
+
249
+ template<class T>
250
+ typename List<T>::iterator List<T>::insert(iterator pos, T&& value) const {
251
+ return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(std::move(value))) };
252
+ }
253
+
254
+ template<class T>
255
+ template<class... Args>
256
+ typename List<T>::iterator List<T>::emplace(iterator pos, Args&&... value) const {
257
+ // TODO Use list_element_from?
258
+ return iterator { impl_->list.emplace(pos.iterator_, std::forward<Args>(value)...) };
259
+ }
260
+
261
+ template<class T>
262
+ void List<T>::push_back(const T& value) const {
263
+ impl_->list.push_back(c10::detail::ListElementFrom<T>::from(value));
264
+ }
265
+
266
+ template<class T>
267
+ void List<T>::push_back(T&& value) const {
268
+ impl_->list.push_back(c10::detail::ListElementFrom<T>::from(std::move(value)));
269
+ }
270
+
271
+ template<class T>
272
+ void List<T>::append(List<T> b) const {
273
+ if (b.use_count() == 1) {
274
+ impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end()));
275
+ } else {
276
+ impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end());
277
+ }
278
+ }
279
+
280
+ template<class T>
281
+ template<class... Args>
282
+ void List<T>::emplace_back(Args&&... args) const {
283
+ // TODO Use list_element_from?
284
+ impl_->list.push_back(T(std::forward<Args>(args)...));
285
+ }
286
+
287
+ template<class T>
288
+ typename List<T>::iterator List<T>::erase(iterator pos) const {
289
+ return iterator { impl_->list.erase(pos.iterator_) };
290
+ }
291
+
292
+ template<class T>
293
+ typename List<T>::iterator List<T>::erase(iterator first, iterator last) const {
294
+ return iterator { impl_->list.erase(first.iterator_, last.iterator_) };
295
+ }
296
+
297
+ template<class T>
298
+ void List<T>::pop_back() const {
299
+ impl_->list.pop_back();
300
+ }
301
+
302
+ template<class T>
303
+ void List<T>::resize(size_type count) const {
304
+ impl_->list.resize(count, T{});
305
+ }
306
+
307
+ template<class T>
308
+ void List<T>::resize(size_type count, const T& value) const {
309
+ impl_->list.resize(count, value);
310
+ }
311
+
312
+ template<class T>
313
+ bool operator==(const List<T>& lhs, const List<T>& rhs) {
314
+ // Lists with the same identity trivially compare equal.
315
+ if (lhs.impl_ == rhs.impl_) {
316
+ return true;
317
+ }
318
+
319
+ // Otherwise, just compare values directly.
320
+ return *lhs.impl_ == *rhs.impl_;
321
+ }
322
+
323
+ template<class T>
324
+ bool operator!=(const List<T>& lhs, const List<T>& rhs) {
325
+ return !(lhs == rhs);
326
+ }
327
+
328
+ template<class T>
329
+ bool List<T>::is(const List<T>& rhs) const {
330
+ return this->impl_ == rhs.impl_;
331
+ }
332
+
333
+ template<class T>
334
+ std::vector<T> List<T>::vec() const {
335
+ std::vector<T> result(begin(), end());
336
+ return result;
337
+ }
338
+
339
+ template<class T>
340
+ size_t List<T>::use_count() const {
341
+ return impl_.use_count();
342
+ }
343
+
344
+ template <class T>
345
+ TypePtr List<T>::elementType() const {
346
+ return impl_->elementType;
347
+ }
348
+
349
+ template <class T>
350
+ void List<T>::unsafeSetElementType(TypePtr t) {
351
+ impl_->elementType = std::move(t);
352
+ }
353
+
354
+ }
355
+
356
+ #else
357
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
358
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/MT19937RNGEngine.h ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/util/irange.h>
5
+
6
+ // define constants like M_PI and C keywords for MSVC
7
+ #ifdef _MSC_VER
8
+ #ifndef _USE_MATH_DEFINES
9
+ #define _USE_MATH_DEFINES
10
+ #endif
11
+ #include <math.h>
12
+ #endif
13
+
14
+ #include <array>
15
+ #include <cmath>
16
+ #include <cstdint>
17
+
18
+ namespace at {
19
+
20
+ constexpr int MERSENNE_STATE_N = 624;
21
+ constexpr int MERSENNE_STATE_M = 397;
22
+ constexpr uint32_t MATRIX_A = 0x9908b0df;
23
+ constexpr uint32_t UMASK = 0x80000000;
24
+ constexpr uint32_t LMASK = 0x7fffffff;
25
+
26
+ /**
27
+ * Note [Mt19937 Engine implementation]
28
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
29
+ * Originally implemented in:
30
+ * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/CODES/MTARCOK/mt19937ar-cok.c
31
+ * and modified with C++ constructs. Moreover the state array of the engine
32
+ * has been modified to hold 32 bit uints instead of 64 bits.
33
+ *
34
+ * Note that we reimplemented mt19937 instead of using std::mt19937 because,
35
+ * at::mt19937 turns out to be faster in the pytorch codebase. PyTorch builds with -O2
36
+ * by default and following are the benchmark numbers (benchmark code can be found at
37
+ * https://github.com/syed-ahmed/benchmark-rngs):
38
+ *
39
+ * with -O2
40
+ * Time to get 100000000 philox randoms with at::uniform_real_distribution = 0.462759s
41
+ * Time to get 100000000 at::mt19937 randoms with at::uniform_real_distribution = 0.39628s
42
+ * Time to get 100000000 std::mt19937 randoms with std::uniform_real_distribution = 0.352087s
43
+ * Time to get 100000000 std::mt19937 randoms with at::uniform_real_distribution = 0.419454s
44
+ *
45
+ * std::mt19937 is faster when used in conjunction with std::uniform_real_distribution,
46
+ * however we can't use std::uniform_real_distribution because of this bug:
47
+ * http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524. Plus, even if we used
48
+ * std::uniform_real_distribution and filtered out the 1's, it is a different algorithm
49
+ * than what's in pytorch currently and that messes up the tests in tests_distributions.py.
50
+ * The other option, using std::mt19937 with at::uniform_real_distribution is a tad bit slower
51
+ * than at::mt19937 with at::uniform_real_distribution and hence, we went with the latter.
52
+ *
53
+ * Copyright notice:
54
+ * A C-program for MT19937, with initialization improved 2002/2/10.
55
+ * Coded by Takuji Nishimura and Makoto Matsumoto.
56
+ * This is a faster version by taking Shawn Cokus's optimization,
57
+ * Matthe Bellew's simplification, Isaku Wada's real version.
58
+ *
59
+ * Before using, initialize the state by using init_genrand(seed)
60
+ * or init_by_array(init_key, key_length).
61
+ *
62
+ * Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura,
63
+ * All rights reserved.
64
+ *
65
+ * Redistribution and use in source and binary forms, with or without
66
+ * modification, are permitted provided that the following conditions
67
+ * are met:
68
+ *
69
+ * 1. Redistributions of source code must retain the above copyright
70
+ * notice, this list of conditions and the following disclaimer.
71
+ *
72
+ * 2. Redistributions in binary form must reproduce the above copyright
73
+ * notice, this list of conditions and the following disclaimer in the
74
+ * documentation and/or other materials provided with the distribution.
75
+ *
76
+ * 3. The names of its contributors may not be used to endorse or promote
77
+ * products derived from this software without specific prior written
78
+ * permission.
79
+ *
80
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
81
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
82
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
83
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
84
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
85
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
86
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
87
+ * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
88
+ * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
89
+ * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
90
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
91
+ *
92
+ *
93
+ * Any feedback is very welcome.
94
+ * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html
95
+ * email: m-mat @ math.sci.hiroshima-u.ac.jp (remove space)
96
+ */
97
+
98
+ /**
99
+ * mt19937_data_pod is used to get POD data in and out
100
+ * of mt19937_engine. Used in torch.get_rng_state and
101
+ * torch.set_rng_state functions.
102
+ */
103
+ struct mt19937_data_pod {
104
+ uint64_t seed_;
105
+ int left_;
106
+ bool seeded_;
107
+ uint32_t next_;
108
+ std::array<uint32_t, MERSENNE_STATE_N> state_;
109
+ };
110
+
111
+ class mt19937_engine {
112
+ public:
113
+
114
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
115
+ inline explicit mt19937_engine(uint64_t seed = 5489) {
116
+ init_with_uint32(seed);
117
+ }
118
+
119
+ inline mt19937_data_pod data() const {
120
+ return data_;
121
+ }
122
+
123
+ inline void set_data(const mt19937_data_pod& data) {
124
+ data_ = data;
125
+ }
126
+
127
+ inline uint64_t seed() const {
128
+ return data_.seed_;
129
+ }
130
+
131
+ inline bool is_valid() {
132
+ if ((data_.seeded_ == true)
133
+ && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N)
134
+ && (data_.next_ <= MERSENNE_STATE_N)) {
135
+ return true;
136
+ }
137
+ return false;
138
+ }
139
+
140
+ inline uint32_t operator()() {
141
+ if (--(data_.left_) == 0) {
142
+ next_state();
143
+ }
144
+ uint32_t y = *(data_.state_.data() + data_.next_++);
145
+ y ^= (y >> 11);
146
+ y ^= (y << 7) & 0x9d2c5680;
147
+ y ^= (y << 15) & 0xefc60000;
148
+ y ^= (y >> 18);
149
+
150
+ return y;
151
+ }
152
+
153
+ private:
154
+ mt19937_data_pod data_;
155
+
156
+ inline void init_with_uint32(uint64_t seed) {
157
+ data_.seed_ = seed;
158
+ data_.seeded_ = true;
159
+ data_.state_[0] = seed & 0xffffffff;
160
+ for (const auto j : c10::irange(1, MERSENNE_STATE_N)) {
161
+ data_.state_[j] = (1812433253 * (data_.state_[j-1] ^ (data_.state_[j-1] >> 30)) + j);
162
+ }
163
+ data_.left_ = 1;
164
+ data_.next_ = 0;
165
+ }
166
+
167
+ inline uint32_t mix_bits(uint32_t u, uint32_t v) {
168
+ return (u & UMASK) | (v & LMASK);
169
+ }
170
+
171
+ inline uint32_t twist(uint32_t u, uint32_t v) {
172
+ return (mix_bits(u,v) >> 1) ^ (v & 1 ? MATRIX_A : 0);
173
+ }
174
+
175
+ inline void next_state() {
176
+ uint32_t* p = data_.state_.data();
177
+ data_.left_ = MERSENNE_STATE_N;
178
+ data_.next_ = 0;
179
+
180
+ for(int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) {
181
+ *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]);
182
+ }
183
+
184
+ for(int j = MERSENNE_STATE_M; --j; p++) {
185
+ *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]);
186
+ }
187
+
188
+ *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]);
189
+ }
190
+
191
+ };
192
+
193
+ typedef mt19937_engine mt19937;
194
+
195
+ } // namespace at
196
+
197
+ #else
198
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
199
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/NamedTensor.h ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/Dimname.h>
5
+ #include <c10/core/TensorImpl.h>
6
+
7
+ namespace at {
8
+
9
+ class TensorBase;
10
+
11
+ // XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen.
12
+ // Due to the c10/ATen library split, TensorImpl cannot depend on Dimname,
13
+ // so we have a couple of workarounds.
14
+ //
15
+ // In the long term, we'll move Dimname to c10 and everything in this file
16
+ // can be refactored out. The main blocker for that is that "c10::Symbol"
17
+ // actually exists outside of c10 and needs to be moved in.
18
+
19
+ // TensorImpl has a unique_ptr<NamedTensorMetaInterface> field.
20
+ // XXX: Ideally we would just put std::optional<vector<Dimname>> into TensorImpl.
21
+ //
22
+ // This class has an important invariant: there must be at least ONE
23
+ // non-wildcard
24
+ struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface {
25
+ // This enum is to remind people that the invariant on constructors is that
26
+ // the list of dimnames must have at least one non-wildcard
27
+ enum HAS_NON_WILDCARD {
28
+ HasNonWildcard
29
+ };
30
+
31
+ explicit NamedTensorMeta(HAS_NON_WILDCARD /*unused*/, DimnameList names)
32
+ : names_(names.vec()) {
33
+ check_invariants();
34
+ }
35
+ explicit NamedTensorMeta(HAS_NON_WILDCARD /*unused*/, std::vector<Dimname>&& names)
36
+ : names_(std::move(names)) {
37
+ check_invariants();
38
+ }
39
+
40
+ std::unique_ptr<c10::NamedTensorMetaInterface> clone() const override {
41
+ return std::make_unique<NamedTensorMeta>(HasNonWildcard, names_);
42
+ }
43
+
44
+ DimnameList names() const { return names_; }
45
+
46
+ // Used for an assertion in TensorImpl.h
47
+ int64_t slow_dim() const override {
48
+ return static_cast<int64_t>(names_.size());
49
+ }
50
+
51
+ void check_invariants() const {
52
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
53
+ std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); }));
54
+ }
55
+
56
+ void set_names(HAS_NON_WILDCARD /*unused*/, DimnameList new_names) {
57
+ TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
58
+ std::copy(new_names.begin(), new_names.end(), names_.begin());
59
+ check_invariants();
60
+ }
61
+
62
+ void set_names(HAS_NON_WILDCARD /*unused*/, std::vector<Dimname>&& new_names) {
63
+ TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
64
+ names_ = std::move(new_names);
65
+ check_invariants();
66
+ }
67
+
68
+ // INVARIANT: at least one Dimname is non-WILDCARD
69
+ std::vector<Dimname> names_;
70
+ };
71
+
72
+ // When NamesMode is disabled, then all operations ignore tensors' names fields.
73
+ // Concretely speaking, all tensors are treated as having nullopt names.
74
+ struct TORCH_API NamesMode {
75
+ static bool is_enabled();
76
+ static void set_enabled(bool enabled);
77
+ };
78
+
79
+
80
+ // A RAII, thread local (!) guard that enables or disables names upon
81
+ // construction, and sets it back to the original value upon destruction.
82
+ struct TORCH_API NoNamesGuard {
83
+ NoNamesGuard() : prev_mode(NamesMode::is_enabled()) {
84
+ NamesMode::set_enabled(false);
85
+ }
86
+ NoNamesGuard(const NoNamesGuard&) = delete;
87
+ NoNamesGuard(NoNamesGuard&&) = delete;
88
+ NoNamesGuard& operator=(const NoNamesGuard&) = delete;
89
+ NoNamesGuard& operator=(NoNamesGuard&&) = delete;
90
+ ~NoNamesGuard() {
91
+ if (initialized) {
92
+ reset();
93
+ }
94
+ }
95
+ void reset() {
96
+ TORCH_INTERNAL_ASSERT(initialized);
97
+ NamesMode::set_enabled(prev_mode);
98
+ }
99
+ private:
100
+ bool prev_mode;
101
+ bool initialized{true};
102
+ };
103
+
104
+ void check_names_valid_for(const TensorBase& tensor, DimnameList names);
105
+ void check_names_valid_for(size_t tensor_dim, DimnameList names);
106
+
107
+ // Sets the names of `tensor` to be `names`.
108
+ TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::optional<DimnameList> names);
109
+ TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector<Dimname>&& names, bool validate_names);
110
+
111
+ constexpr size_t kMaxNamedTensorDim = 64;
112
+
113
+ DimnameList default_names(size_t len);
114
+
115
+ namespace impl {
116
+
117
+ // Some helper functions on TensorImpl. Useful for working with names in TH.
118
+ // XXX: Ideally these would exist as methods on TensorImpl
119
+ TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::optional<DimnameList> names, bool validate_names);
120
+ TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names);
121
+
122
+ void check_names_valid_for(TensorImpl* impl, DimnameList names);
123
+
124
+ // Returns true if the tensor's names exist and are not all 'None'.
125
+ // Returns false if the tensor's names don't exist (were not allocated),
126
+ // or if all names are 'None'.
127
+ // We treat not-allocated-names the same as allocated names that are all 'None'.
128
+ TORCH_API bool has_names(const TensorImpl* impl);
129
+
130
+ // Returns the names of the tensor's dimensions.
131
+ // Unnamed tensors are treated as having 'None' in all dimension; this method
132
+ // would return a DimnameList of all 'None's for an unnamed tensor.
133
+ TORCH_API DimnameList get_names(const TensorImpl* impl);
134
+
135
+ // This is more of an implementation detail; one should use impl::get_names /
136
+ // Tensor::names() whenever possible because it provides a cleaner API.
137
+ // Returns the names of the tensor if they have been allocated; returns nullopt
138
+ // instead if the haven't been. The names of a tensor are not allocated if a
139
+ // tensor is constructed with names=None.
140
+ TORCH_API std::optional<DimnameList> get_opt_names(const TensorImpl* impl);
141
+
142
+ } // namespace impl
143
+
144
+ } // namespace at
145
+
146
+ #else
147
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
148
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/ConstantSymNodeImpl.h>
5
+ #include <c10/core/SymNodeImpl.h>
6
+ #include <c10/macros/Export.h>
7
+ #include <c10/util/Exception.h>
8
+ #include <c10/util/intrusive_ptr.h>
9
+ #include <cstdint>
10
+ #include <optional>
11
+ #include <string>
12
+
13
+ namespace c10 {
14
+
15
+ // The motivating usecase for this is to represent the ragged size structure
16
+ // of a jagged tensor [B, [s_0, s_1, s_2], D] as a single integer j0. This
17
+ // allows us to simply return [B, j0, D] if someone queries for the size of our
18
+ // tensor.
19
+ //
20
+ // Morally we define comparison between two nested ints to return true if
21
+ // that comparison holds for all corresponding elements of the arrays they
22
+ // represent. Comparison between a nested int and a plain int is defined
23
+ // similarly.
24
+ //
25
+ // To simulate this desired behavior but also avoid the O(N) cost of checking,
26
+ // we associate each raggedness pattern with an integer "id" that can be used as
27
+ // a proxy to evaluate equality. We also constrain the range of values for this
28
+ // as to enable inequality checks.
29
+ //
30
+ // We also support a positive integer scalar "coeff" that is used for computing
31
+ // strides. For example given, a [B, j0, D] tensor, it can be strided in two
32
+ // different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to
33
+ // differentiate the two cases.
34
+ //
35
+ // During tracing the strides of the outputs need to be a function of the size
36
+ // and strides of the inputs so it is important that NestedIntSymNode itself is
37
+ // able to express this.
38
+ class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
39
+ public:
40
+ // CAUTION: you should probably not be constructing these directly; please
41
+ // the higher-level API in python instead (TODO: actually introduce that).
42
+ explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff)
43
+ : val_(val), coeff_(coeff) {}
44
+
45
+ bool bool_() override {
46
+ return false;
47
+ }
48
+
49
+ bool is_int() override {
50
+ return true;
51
+ }
52
+
53
+ bool is_float() override {
54
+ return false;
55
+ }
56
+
57
+ bool is_bool() override {
58
+ return false;
59
+ }
60
+
61
+ bool is_nested_int() const override {
62
+ return true;
63
+ }
64
+
65
+ bool has_hint() override {
66
+ return true;
67
+ }
68
+
69
+ c10::SymNode wrap_int(int64_t num) override {
70
+ return SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(num));
71
+ }
72
+
73
+ int64_t guard_int(const char* file, int64_t line) override {
74
+ TORCH_CHECK(false);
75
+ }
76
+
77
+ double guard_float(const char* file, int64_t line) override {
78
+ TORCH_CHECK(false, "not a float");
79
+ }
80
+
81
+ bool guard_bool(const char* file, int64_t line) override {
82
+ TORCH_CHECK(false, "not a bool");
83
+ }
84
+
85
+ int64_t int_() override {
86
+ TORCH_CHECK(false);
87
+ }
88
+
89
+ std::string str() override {
90
+ if (coeff_ == 1) {
91
+ return "j" + std::to_string(val_);
92
+ }
93
+ return std::to_string(coeff_) + "*j" + std::to_string(val_);
94
+ }
95
+
96
+ // NOTE [ Inequalities with nested int ]
97
+ //
98
+ // The semantics of nested int when it comes to relations is that it is
99
+ // treated as integer known to be within a certain range,
100
+ //
101
+ // j0 \in [2, int64_t::max]
102
+ //
103
+ // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False).
104
+ // This is a useful default range for the raggedness pattern of a jagged
105
+ // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1
106
+ // specialization checks.
107
+ //
108
+ // [ Indeterminate inequalities error out ]
109
+ //
110
+ // Given the semantic defined above, certain relations like j0 < 3 are thus
111
+ // indeterminable. In our impl today, evaluating such relations error
112
+ //
113
+ // It may seem convenient to just define indeterminate relations to return
114
+ // False, but the implementation we maintain in parallel using sympy does not
115
+ // allow this.
116
+ //
117
+ // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are,
118
+ // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This
119
+ // would mean that means that if we define the indeterminate j0 >= 3 to be
120
+ // False, the also indeterminate j0 < 3 will be evaluated to be True!
121
+ //
122
+ // [ Coefficient are assumed positive ]
123
+ //
124
+ // For the purpose of computing inequalities, we consider the coefficient of
125
+ // the nested int to be a positive integer.
126
+ //
127
+ // Thus, no modifications are needed to the logic since
128
+ // j0 >= k implies coeff * j0 >= k
129
+ //
130
+ c10::SymNode eq(const c10::SymNode& other) override;
131
+ c10::SymNode ne(const c10::SymNode& other) override;
132
+ c10::SymNode ge(const c10::SymNode& other) override;
133
+ c10::SymNode gt(const c10::SymNode& other) override;
134
+ c10::SymNode lt(const c10::SymNode& other) override;
135
+ c10::SymNode le(const c10::SymNode& other) override;
136
+ c10::SymNode mul(const c10::SymNode& other) override;
137
+
138
+ std::optional<int64_t> nested_int() override {
139
+ return val_;
140
+ }
141
+
142
+ std::optional<int64_t> nested_int_coeff() override {
143
+ return coeff_;
144
+ }
145
+
146
+ bool is_symbolic() override {
147
+ return false;
148
+ }
149
+
150
+ c10::SymNode clone() override;
151
+
152
+ #define DEFINE_BINARY_NOT_SUPPORTED(name) \
153
+ c10::SymNode name(const c10::SymNode& other) override { \
154
+ TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \
155
+ }
156
+
157
+ DEFINE_BINARY_NOT_SUPPORTED(add)
158
+ DEFINE_BINARY_NOT_SUPPORTED(sub)
159
+ DEFINE_BINARY_NOT_SUPPORTED(truediv)
160
+ DEFINE_BINARY_NOT_SUPPORTED(pow)
161
+ DEFINE_BINARY_NOT_SUPPORTED(floordiv)
162
+ DEFINE_BINARY_NOT_SUPPORTED(mod)
163
+ DEFINE_BINARY_NOT_SUPPORTED(sym_min)
164
+ DEFINE_BINARY_NOT_SUPPORTED(sym_max)
165
+ DEFINE_BINARY_NOT_SUPPORTED(sym_and)
166
+ DEFINE_BINARY_NOT_SUPPORTED(sym_or)
167
+
168
+ #undef DEFINE_BINARY_NOT_SUPPORTED
169
+
170
+ #define DEFINE_NOT_SUPPORTED(name) \
171
+ c10::SymNode name() override { \
172
+ TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \
173
+ }
174
+
175
+ DEFINE_NOT_SUPPORTED(sym_not)
176
+ DEFINE_NOT_SUPPORTED(ceil)
177
+ DEFINE_NOT_SUPPORTED(floor)
178
+ DEFINE_NOT_SUPPORTED(neg)
179
+ DEFINE_NOT_SUPPORTED(sym_float)
180
+
181
+ #undef DEFINE_NOT_SUPPORTED
182
+
183
+ private:
184
+ int64_t val_;
185
+ int64_t coeff_;
186
+ };
187
+
188
+ } // namespace c10
189
+
190
+ #else
191
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
192
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // define constants like M_PI and C keywords for MSVC
5
+ #ifdef _MSC_VER
6
+ #define _USE_MATH_DEFINES
7
+ #include <math.h>
8
+ #endif
9
+
10
+
11
+ #ifdef __CUDACC__
12
+ #include <cuda.h>
13
+ #endif
14
+
15
+ #include <array>
16
+ #include <c10/macros/Macros.h>
17
+ #include <cmath>
18
+ #include <cstdint>
19
+
20
+ namespace at {
21
+
22
+ // typedefs for holding vector data
23
+ namespace detail {
24
+
25
+ typedef std::array<uint32_t, 4> UINT4;
26
+ typedef std::array<uint32_t, 2> UINT2;
27
+ typedef std::array<double, 2> DOUBLE2;
28
+ typedef std::array<float, 2> FLOAT2;
29
+
30
+ } // namespace detail
31
+
32
+ /**
33
+ * Note [Philox Engine implementation]
34
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
35
+ * Originally implemented in PyTorch's fusion compiler
36
+ * Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
37
+ * for details regarding the engine.
38
+ *
39
+ * Note that currently this implementation of the philox engine is not used
40
+ * anywhere except for tests in cpu_generator_test.cpp. However, this engine
41
+ * will replace curandStatePhilox4_32_10_t in the future.
42
+ *
43
+ * The philox engine takes a seed value, a subsequeunce
44
+ * for starting the generation and an offset for the subsequence.
45
+ * Think of this engine as an algorithm producing a huge array. We are
46
+ * parallelizing this array by partitioning the huge array and assigning
47
+ * a thread index to each partition. In other words, each seed value
48
+ * (there are 2^64 possible seed values) gives a sub array of size
49
+ * 2^128 (each element in that array is a 128 bit number). Reasoning
50
+ * behind the array being of size 2^128 is, there are 2^64 possible
51
+ * thread index value and there is an array of size 2^64 for each of
52
+ * those thread index. Hence 2^64 * 2^64 = 2^128 for each seed value.
53
+ *
54
+ * In short, this generator can produce 2^64 (seed values) * 2^128 (number
55
+ * of elements in an array given by a seed value) = 2^192 values.
56
+ *
57
+ * Arguments:
58
+ * seed: Seed values could be any number from 0 to 2^64-1.
59
+ * subsequence: Subsequence is just the cuda thread indexing with:
60
+ * - blockIdx.x * blockDim.x + threadIdx.x
61
+ * offset: The offset variable in PhiloxEngine decides how many 128-bit
62
+ * random numbers to skip (i.e. how many groups of 4, 32-bit numbers to skip)
63
+ * and hence really decides the total number of randoms that can be achieved
64
+ * for the given subsequence.
65
+ */
66
+
67
+ class philox_engine {
68
+ public:
69
+
70
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
71
+ C10_HOST_DEVICE inline explicit philox_engine(uint64_t seed = 67280421310721,
72
+ uint64_t subsequence = 0,
73
+ uint64_t offset = 0) {
74
+
75
+ reset_state(seed, subsequence);
76
+ incr_n(offset);
77
+ }
78
+
79
+ C10_HOST_DEVICE inline void reset_state(uint64_t seed = 67280421310721,
80
+ uint64_t subsequence = 0) {
81
+ key_[0] = static_cast<uint32_t>(seed);
82
+ key_[1] = static_cast<uint32_t>(seed >> 32);
83
+ counter_ = detail::UINT4{};
84
+ counter_[2] = static_cast<uint32_t>(subsequence);
85
+ counter_[3] = static_cast<uint32_t>(subsequence >> 32);
86
+ STATE = 0;
87
+ }
88
+
89
+ /**
90
+ * Set the offset field of Philox Generator to the desired offset.
91
+ */
92
+ C10_HOST_DEVICE inline void set_offset(uint64_t offset) {
93
+ counter_[0] = static_cast<uint32_t>(offset);
94
+ counter_[1] = static_cast<uint32_t>(offset >> 32);
95
+ }
96
+
97
+ /**
98
+ * Gets the current offset of the Philox Generator.
99
+ */
100
+ C10_HOST_DEVICE uint64_t get_offset() const {
101
+ uint64_t lo = static_cast<uint64_t>(counter_[0]);
102
+ uint64_t hi = static_cast<uint64_t>(counter_[1]) << 32;
103
+ return lo | hi;
104
+ }
105
+
106
+ /**
107
+ * Produces a unique 32-bit pseudo random number on every invocation. Bookeeps state to avoid waste.
108
+ */
109
+ C10_HOST_DEVICE inline uint32_t operator()(int32_t n_rounds = 10) { // 10 here to preserve back-compat behavior
110
+ if(STATE == 0) {
111
+ detail::UINT4 counter = counter_;
112
+ detail::UINT2 key = key_;
113
+ output_ = rand(counter, key, n_rounds);
114
+ incr();
115
+ }
116
+ uint32_t ret = output_[static_cast<int>(STATE)];
117
+ STATE = (STATE + 1) & 3;
118
+ return ret;
119
+ }
120
+
121
+ inline float randn(uint32_t n_rounds) {
122
+ #ifdef __CUDA_ARCH__
123
+ AT_ASSERT(false, "Unsupported invocation of randn on CUDA");
124
+ #endif
125
+ if(STATE == 0) {
126
+ detail::UINT4 counter = counter_;
127
+ detail::UINT2 key = key_;
128
+ output_ = rand(counter, key, n_rounds);
129
+ incr();
130
+ }
131
+ // TODO(min-jean-cho) change to Polar method, a more efficient version of Box-Muller method
132
+ // TODO(voz) We use std:: below, and thus need a separate impl for CUDA.
133
+ float u1 = 1 - uint32_to_uniform_float(output_[0]); // uint32_to_uniform_float returns [0,1), we need (0,1] to avoid passing 0 to log.
134
+ float u2 = 1 - uint32_to_uniform_float(output_[1]);
135
+ return static_cast<float>(std::sqrt(-2.0 * std::log(u1)) * std::cos(2.0 * M_PI * u2));
136
+ }
137
+
138
+ /**
139
+ * Function that Skips N 128 bit numbers in a subsequence
140
+ */
141
+ C10_HOST_DEVICE inline void incr_n(uint64_t n) {
142
+ uint32_t nlo = static_cast<uint32_t>(n);
143
+ uint32_t nhi = static_cast<uint32_t>(n >> 32);
144
+ counter_[0] += nlo;
145
+ // if overflow in x has occurred, carry over to nhi
146
+ if (counter_[0] < nlo) {
147
+ nhi++;
148
+ // if overflow in nhi has occurred during carry over,
149
+ // propagate that overflow to y and exit to increment z
150
+ // otherwise return
151
+ counter_[1] += nhi;
152
+ if(nhi != 0) {
153
+ if (nhi <= counter_[1]) {
154
+ return;
155
+ }
156
+ }
157
+ } else {
158
+ // if overflow in y has occurred during addition,
159
+ // exit to increment z
160
+ // otherwise return
161
+ counter_[1] += nhi;
162
+ if (nhi <= counter_[1]) {
163
+ return;
164
+ }
165
+ }
166
+ if (++counter_[2])
167
+ return;
168
+ ++counter_[3];
169
+ }
170
+
171
+ /**
172
+ * Function that Skips one 128 bit number in a subsequence
173
+ */
174
+ C10_HOST_DEVICE inline void incr() {
175
+ if (++counter_[0])
176
+ return;
177
+ if (++counter_[1])
178
+ return;
179
+ if (++counter_[2]) {
180
+ return;
181
+ }
182
+ ++counter_[3];
183
+ }
184
+
185
+ private:
186
+ detail::UINT4 counter_;
187
+ detail::UINT4 output_;
188
+ detail::UINT2 key_;
189
+ uint32_t STATE;
190
+
191
+ C10_HOST_DEVICE inline uint32_t mulhilo32(uint32_t a, uint32_t b,
192
+ uint32_t *result_high) {
193
+ #ifdef __CUDA_ARCH__
194
+ *result_high = __umulhi(a, b);
195
+ return a*b;
196
+ #else
197
+ const uint64_t product = static_cast<uint64_t>(a) * b;
198
+ *result_high = static_cast<uint32_t>(product >> 32);
199
+ return static_cast<uint32_t>(product);
200
+ #endif
201
+ }
202
+
203
+ C10_HOST_DEVICE inline detail::UINT4 single_round(detail::UINT4 ctr, detail::UINT2 in_key) {
204
+ uint32_t hi0 = 0;
205
+ uint32_t hi1 = 0;
206
+ uint32_t lo0 = mulhilo32(kPhiloxSA, ctr[0], &hi0);
207
+ uint32_t lo1 = mulhilo32(kPhiloxSB, ctr[2], &hi1);
208
+ detail::UINT4 ret;
209
+ ret[0] = hi1 ^ ctr[1] ^ in_key[0];
210
+ ret[1] = lo1;
211
+ ret[2] = hi0 ^ ctr[3] ^ in_key[1];
212
+ ret[3] = lo0;
213
+ return ret;
214
+ }
215
+
216
+ C10_HOST_DEVICE constexpr float uint32_to_uniform_float(uint32_t value) {
217
+ // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
218
+ constexpr float scale = 4.6566127342e-10;
219
+ return static_cast<float>(value & 0x7FFFFFFF) * scale;
220
+ }
221
+
222
+
223
+
224
+ C10_HOST_DEVICE inline detail::UINT4 rand(detail::UINT4& counter, detail::UINT2& key, uint32_t n_rounds) {
225
+ for (uint32_t round = 0; round < (n_rounds - 1); round++) {
226
+ counter = single_round(counter, key);
227
+ key[0] += (kPhilox10A); key[1] += (kPhilox10B);
228
+ }
229
+ return single_round(counter, key);
230
+ }
231
+
232
+
233
+ static constexpr uint32_t kPhilox10A = 0x9E3779B9;
234
+ static constexpr uint32_t kPhilox10B = 0xBB67AE85;
235
+ static constexpr uint32_t kPhiloxSA = 0xD2511F53;
236
+ static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
237
+ };
238
+
239
+ typedef philox_engine Philox4_32;
240
+
241
+ } // namespace at
242
+
243
+ #else
244
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
245
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonFallbackKernel.h ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/core/TorchDispatchUtils.h>
4
+
5
+
6
+ namespace at::impl {
7
+
8
+ struct TORCH_API RestorePythonTLSSnapshot {
9
+ RestorePythonTLSSnapshot();
10
+ RestorePythonTLSSnapshot(RestorePythonTLSSnapshot&& other) = delete;
11
+ RestorePythonTLSSnapshot(const RestorePythonTLSSnapshot&) = delete;
12
+ RestorePythonTLSSnapshot& operator=(const RestorePythonTLSSnapshot&) = delete;
13
+ RestorePythonTLSSnapshot& operator=(RestorePythonTLSSnapshot&&) = delete;
14
+ ~RestorePythonTLSSnapshot();
15
+
16
+ private:
17
+ c10::impl::LocalDispatchKeySet saved_;
18
+ c10::impl::ForceDispatchKeyGuard guard_;
19
+ };
20
+
21
+
22
+ // RAII guard to make working with the above TLS safer.
23
+ struct TORCH_API MaybeSetTLSOnEntryGuard {
24
+ public:
25
+ MaybeSetTLSOnEntryGuard();
26
+ MaybeSetTLSOnEntryGuard(MaybeSetTLSOnEntryGuard&& other) = delete;
27
+ MaybeSetTLSOnEntryGuard(const MaybeSetTLSOnEntryGuard&) = delete;
28
+ MaybeSetTLSOnEntryGuard& operator=(const MaybeSetTLSOnEntryGuard&) = delete;
29
+ MaybeSetTLSOnEntryGuard& operator=(MaybeSetTLSOnEntryGuard&&) = delete;
30
+ ~MaybeSetTLSOnEntryGuard();
31
+
32
+ private:
33
+ bool value_set_;
34
+ };
35
+
36
+ } // namespace at::impl
37
+
38
+ #else
39
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
40
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/dispatch/Dispatcher.h>
5
+
6
+ // TODO: this can probably live in c10
7
+
8
+
9
+ namespace at::impl {
10
+
11
+ class TORCH_API PythonOpRegistrationTrampoline final {
12
+ static std::atomic<c10::impl::PyInterpreter*> interpreter_;
13
+
14
+ public:
15
+ // Returns true if you successfully registered yourself (that means
16
+ // you are in the hot seat for doing the operator registrations!)
17
+ static bool registerInterpreter(c10::impl::PyInterpreter* /*interp*/);
18
+
19
+ // Returns nullptr if no interpreter has been registered yet.
20
+ static c10::impl::PyInterpreter* getInterpreter();
21
+ };
22
+
23
+ } // namespace at::impl
24
+
25
+ #else
26
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
27
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/QuantizerBase.h ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/ScalarType.h>
5
+ #include <c10/core/QScheme.h>
6
+ #include <c10/util/intrusive_ptr.h>
7
+
8
+ namespace at {
9
+
10
+ class Tensor;
11
+ struct QTensorImpl;
12
+ struct Quantizer;
13
+ using ConstQuantizerPtr = const c10::intrusive_ptr<Quantizer>&;
14
+ using QuantizerPtr = c10::intrusive_ptr<Quantizer>;
15
+
16
+ /**
17
+ * Quantizer is the class for storing all the information
18
+ * that's necessary to perform quantize and dequantize
19
+ * operation.
20
+ *
21
+ * We might have different types of quantization schemes and this is
22
+ * the base class for all quantizers.
23
+ *
24
+ * QTensorImpl will hold a pointer to Quantizer so that we can support
25
+ * different quantization schemes on Tensor.
26
+ *
27
+ * For example, the most common quantization scheme, Affine Quantization,
28
+ * requires scale and zero_point as parameters, we'll store scale and zero_point
29
+ * inside the instance and we can use it to quantize a float Tensor or
30
+ * dequantize a quantized Tensor.
31
+ *
32
+ * When you add new types of leaf Quantizer class, please also
33
+ * make sure to add a corresponding QScheme enum since
34
+ * they should have one to one mapping.
35
+ *
36
+ * Note about intrusive_ptr:
37
+ * Quantized Tensor holds an intrusive_ptr to Quantizer, and multiple Tensor can
38
+ * share the same Quantizer. Quantizer should be immutable.
39
+ */
40
+ struct TORCH_API Quantizer : public c10::intrusive_ptr_target {
41
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
42
+ const ScalarType scalar_type_;
43
+ explicit Quantizer(ScalarType scalar_type) : scalar_type_(scalar_type) {}
44
+ ~Quantizer() override = default;
45
+
46
+ // Copied from torch/csrc/jit/ir/scope.h
47
+ QuantizerPtr intrusive_from_this() {
48
+ c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
49
+ // from a raw `this` pointer
50
+ // so we need to bump the refcount
51
+ // to account for this ownership
52
+ return c10::intrusive_ptr<Quantizer>::reclaim(this);
53
+ }
54
+
55
+ /**
56
+ * Each concrete Quantizer type should have a unique QScheme type.
57
+ */
58
+ virtual QScheme qscheme() const = 0;
59
+
60
+ ScalarType scalar_type() const {
61
+ return scalar_type_;
62
+ }
63
+
64
+ /**
65
+ * quantize a float Tensor into a quantized Tensor.
66
+ */
67
+ virtual Tensor quantize(const Tensor& t) = 0;
68
+
69
+ /**
70
+ * dequantize a quantized Tensor into a float Tensor.
71
+ */
72
+ virtual Tensor dequantize(const Tensor& t) = 0;
73
+
74
+ /**
75
+ * dequantize a quantized Tensor into a float Tensor, out= variant
76
+ */
77
+ virtual Tensor& dequantize_out(Tensor& out, const Tensor& t) = 0;
78
+
79
+ /**
80
+ * Compare against `other` for equality.
81
+ */
82
+ virtual bool equalTo(QuantizerPtr other) const = 0;
83
+ };
84
+
85
+ } // namespace at
86
+
87
+ #else
88
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
89
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Range.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <cstdint>
5
+ #include <iosfwd>
6
+
7
+ namespace at {
8
+
9
+ struct Range {
10
+ Range(int64_t begin, int64_t end)
11
+ : begin(begin)
12
+ , end(end) {}
13
+
14
+ int64_t size() const { return end - begin; }
15
+
16
+ Range operator/(int64_t divisor) {
17
+ return Range(begin / divisor, end / divisor);
18
+ }
19
+
20
+ int64_t begin;
21
+ int64_t end;
22
+ };
23
+
24
+ std::ostream& operator<<(std::ostream& out, const Range& range);
25
+
26
+ } // namespace at
27
+
28
+ #else
29
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
30
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Reduction.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ namespace at::Reduction {
5
+
6
+ // NB: Keep this in sync with Reduction class in torch/nn/_reduction.py
7
+ // These constants control the reduction behavior of loss functions.
8
+ // Ideally, this would be a scoped enum, but jit doesn't support that
9
+ enum Reduction {
10
+ None, // Do not reduce
11
+ Mean, // (Possibly weighted) mean of losses
12
+ Sum, // Sum losses
13
+ END
14
+ };
15
+ } // namespace at::Reduction
16
+
17
+ #else
18
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
19
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Scalar.h ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #include <c10/core/Scalar.h>
3
+
4
+ #else
5
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
6
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ScalarType.h ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #include <c10/core/ScalarType.h>
3
+
4
+ #else
5
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
6
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Tensor.h ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/TensorBody.h>
5
+ #include <c10/util/Exception.h>
6
+
7
+ namespace at {
8
+ // NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
9
+ class TORCH_API OptionalTensorRef {
10
+ public:
11
+ OptionalTensorRef() = default;
12
+
13
+ ~OptionalTensorRef() {
14
+ ref_.unsafeReleaseTensorImpl();
15
+ }
16
+
17
+ OptionalTensorRef(const TensorBase& src)
18
+ : ref_(Tensor::unsafe_borrow_t{}, src) {
19
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
20
+ }
21
+
22
+ OptionalTensorRef(const OptionalTensorRef& rhs)
23
+ : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {}
24
+
25
+ OptionalTensorRef(OptionalTensorRef&& rhs) = default;
26
+ OptionalTensorRef& operator=(OptionalTensorRef rhs) {
27
+ std::swap(ref_, rhs.ref_);
28
+ return *this;
29
+ }
30
+
31
+ bool has_value() const {
32
+ return ref_.defined();
33
+ }
34
+
35
+ const Tensor& getTensorRef() const & {
36
+ return ref_;
37
+ }
38
+
39
+ const Tensor& operator*() const & {
40
+ return ref_;
41
+ }
42
+
43
+ const Tensor* operator->() const & {
44
+ return &ref_;
45
+ }
46
+
47
+ operator bool() const {
48
+ return ref_.defined();
49
+ }
50
+
51
+ private:
52
+ Tensor ref_;
53
+ };
54
+
55
+ // Use to convert a TensorBase (that may be undefined) to an at::Tensor
56
+ // without bumping refcount.
57
+ class TORCH_API TensorRef {
58
+ public:
59
+ ~TensorRef() {
60
+ ref_.unsafeReleaseTensorImpl();
61
+ }
62
+
63
+ TensorRef(const TensorBase& src)
64
+ : ref_(Tensor::unsafe_borrow_t{}, src) {}
65
+ TensorRef(TensorRef&& other) = default;
66
+ TensorRef(const TensorRef&) = default;
67
+ TensorRef& operator=(const TensorRef&) = default;
68
+ TensorRef& operator=(TensorRef&&) = default;
69
+
70
+ const Tensor& operator*() const & {
71
+ return ref_;
72
+ }
73
+ private:
74
+ Tensor ref_;
75
+ };
76
+
77
+ template <typename T>
78
+ auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> {
79
+ // Return the grad argument in case of a hook with void return type to have an
80
+ // std::function with Tensor return type
81
+ static_assert(std::is_same_v<decltype(hook(Tensor())), void>,
82
+ "Expected hook to return void");
83
+ return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
84
+ TensorRef grad(grad_base);
85
+ fn(*grad);
86
+ return Tensor();
87
+ });
88
+ }
89
+
90
+ template <typename T>
91
+ auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> {
92
+ return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
93
+ TensorRef grad(grad_base);
94
+ Tensor ret = fn(*grad);
95
+ return TensorBase(std::move(ret));
96
+ });
97
+ }
98
+
99
+ } // namespace at
100
+
101
+ #else
102
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
103
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorAccessor.h ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <torch/headeronly/core/TensorAccessor.h>
5
+ #include <c10/macros/Macros.h>
6
+ #include <c10/util/ArrayRef.h>
7
+ #include <c10/util/Deprecated.h>
8
+ #include <c10/util/Exception.h>
9
+ #include <c10/util/irange.h>
10
+ #include <cstddef>
11
+ #include <cstdint>
12
+ #include <type_traits>
13
+
14
+ namespace at {
15
+
16
+ using torch::headeronly::DefaultPtrTraits;
17
+ #if defined(__CUDACC__) || defined(__HIPCC__)
18
+ using torch::headeronly::RestrictPtrTraits;
19
+ #endif
20
+
21
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
22
+ using TensorAccessorBase = torch::headeronly::detail::TensorAccessorBase<c10::IntArrayRef, T, N, PtrTraits, index_t>;
23
+
24
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
25
+ using TensorAccessor = torch::headeronly::detail::TensorAccessor<c10::IntArrayRef, T, N, PtrTraits, index_t>;
26
+
27
+ namespace detail {
28
+
29
+ template <size_t N, typename index_t>
30
+ struct IndexBoundsCheck {
31
+ IndexBoundsCheck(index_t i) {
32
+ TORCH_CHECK_INDEX(
33
+ 0 <= i && i < index_t{N},
34
+ "Index ",
35
+ i,
36
+ " is not within bounds of a tensor of dimension ",
37
+ N);
38
+ }
39
+ };
40
+ } // namespace detail
41
+
42
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
43
+ using GenericPackedTensorAccessorBase = torch::headeronly::detail::GenericPackedTensorAccessorBase<detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
44
+
45
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
46
+ using GenericPackedTensorAccessor = torch::headeronly::detail::GenericPackedTensorAccessor<TensorAccessor<T, N-1, PtrTraits, index_t>, detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
47
+
48
+ // Can't put this directly into the macro function args because of commas
49
+ #define AT_X GenericPackedTensorAccessor<T, N, PtrTraits, index_t>
50
+
51
+ // Old name for `GenericPackedTensorAccessor`
52
+ template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
53
+ C10_DEFINE_DEPRECATED_USING(PackedTensorAccessor, AT_X)
54
+
55
+ #undef AT_X
56
+
57
+ template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
58
+ using PackedTensorAccessor32 = GenericPackedTensorAccessor<T, N, PtrTraits, int32_t>;
59
+
60
+ template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
61
+ using PackedTensorAccessor64 = GenericPackedTensorAccessor<T, N, PtrTraits, int64_t>;
62
+ } // namespace at
63
+
64
+ #else
65
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
66
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBase.h ADDED
@@ -0,0 +1,1098 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // See https://github.com/pytorch/pytorch/issues/161660
5
+ // This compile flag is intended to be passed in to CppExtensions that rely on
6
+ // the stable ABI via the `extra_compile_args` argument. This is a stopgap
7
+ // solution to ensure that non-stable libtorch APIs are not used in the extension.
8
+ // The long term solution is to have a torch_stable target that excludes headers
9
+ // that are not in torch/stable or torch/headeronly.
10
+ // See test/cpp_extensions/torch_stable_test_extension/setup.py for an example
11
+ // of how this is used.
12
+ #ifdef TORCH_STABLE_ONLY
13
+ #error \
14
+ "TensorBase.h should not be included when TORCH_STABLE_ONLY compile flag is passed"
15
+ #endif
16
+
17
+ #include <c10/core/Device.h>
18
+ #include <c10/core/Layout.h>
19
+ #include <c10/core/MemoryFormat.h>
20
+ #include <c10/core/ScalarType.h>
21
+ #include <c10/core/ScalarTypeToTypeMeta.h>
22
+ #include <c10/core/Storage.h>
23
+ #include <c10/core/SymIntArrayRef.h>
24
+ #include <c10/core/TensorImpl.h>
25
+ #include <c10/core/TensorOptions.h>
26
+ #include <c10/core/UndefinedTensorImpl.h>
27
+ #include <c10/core/WrapDimMinimal.h>
28
+ #include <c10/util/C++17.h>
29
+ #include <c10/util/Exception.h>
30
+ #include <c10/util/ExclusivelyOwned.h>
31
+ #include <c10/util/ExclusivelyOwnedTensorTraits.h>
32
+ #include <c10/util/MaybeOwned.h>
33
+ #include <optional>
34
+ #include <c10/util/intrusive_ptr.h>
35
+
36
+ #include <ATen/core/NamedTensor.h>
37
+ #include <ATen/core/QuantizerBase.h>
38
+ #include <ATen/core/TensorAccessor.h>
39
+ #include <ATen/StorageUtils.h>
40
+
41
+ namespace c10 {
42
+ class Scalar;
43
+ }
44
+
45
+ namespace torch::autograd {
46
+
47
+ struct Node;
48
+
49
+ } // namespace torch::autograd
50
+
51
+ namespace at {
52
+
53
+ class Tensor;
54
+ class TensorBase;
55
+
56
+ // Convert Tensor to TensorBase without any need to include Tensor.h
57
+ TORCH_API const TensorBase& get_tensor_base(const Tensor& t);
58
+
59
+ namespace impl {
60
+ inline bool variable_excluded_from_dispatch() {
61
+ #ifdef C10_MOBILE
62
+ // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
63
+ return true;
64
+ #else
65
+ return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
66
+ #endif
67
+ }
68
+
69
+ }
70
+
71
+ // NOTE: [Tensor vs. TensorBase]
72
+ //
73
+ // Tensor, being the central data structure in PyTorch, gets used and
74
+ // its header included almost everywhere. Unfortunately this means
75
+ // every time an operator signature is updated or changed in
76
+ // native_functions.yaml, you (and every other PyTorch developer) need
77
+ // to recompile all of ATen and its dependencies.
78
+ //
79
+ // TensorBase aims to break up these header dependencies, and improve
80
+ // incremental build times for all PyTorch developers. TensorBase
81
+ // represents a reference counted handle to TensorImpl, exactly the
82
+ // same as Tensor. However, TensorBase doesn't have code generated
83
+ // methods in its API and thus no dependence on native_functions.yaml.
84
+ //
85
+ // Usage tips
86
+ // ----------
87
+ // - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp
88
+ // or .cu file to ensure it has no header dependencies on
89
+ // native_functions.yaml (direct or indirect).
90
+ // - Tensor inherits from TensorBase, so functions taking
91
+ // `const TensorBase &` are callable with Tensor as well.
92
+ // - TensorBase can be converted to Tensor with `Tensor(tensor_base)`,
93
+ // but this requires a reference-count bump. OptionalTensorRef, on
94
+ // the other hand, can materialize a `const Tensor &` without
95
+ // touching the reference-count.
96
+ class TORCH_API TensorBase {
97
+ public:
98
+ struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };
99
+
100
+ protected:
101
+ // Create a Tensor with a +0 reference count. Special care must be
102
+ // taken to avoid decrementing this reference count at destruction
103
+ // time. Intended to support MaybeOwnedTraits<Tensor>.
104
+ explicit TensorBase(unsafe_borrow_t /*unused*/, const TensorBase& rhs)
105
+ : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>(rhs.impl_.get(), c10::raw::DontIncreaseRefcount{})) {}
106
+ friend MaybeOwnedTraits<TensorBase>;
107
+
108
+ public:
109
+ TensorBase() = default;
110
+ // This constructor should not be used by end users and is an implementation
111
+ // detail invoked by autogenerated code.
112
+ explicit TensorBase(
113
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
114
+ : impl_(std::move(tensor_impl)) {
115
+ TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
116
+ }
117
+ TensorBase(const TensorBase&) = default;
118
+ TensorBase(TensorBase&&) noexcept = default;
119
+ ~TensorBase() noexcept = default;
120
+
121
+ public:
122
+ // Creates a new wrapper from TensorImpl. Intentionally a free method because
123
+ // it should be used with care. Checks necessary invariants
124
+ static TensorBase wrap_tensor_impl(
125
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
126
+ TensorBase r(std::move(tensor_impl));
127
+ r.enforce_invariants();
128
+ return r;
129
+ }
130
+
131
+ int64_t dim() const {
132
+ return impl_->dim();
133
+ }
134
+ int64_t storage_offset() const {
135
+ return impl_->storage_offset();
136
+ }
137
+
138
+ TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
139
+ if (is_contiguous_or_false(memory_format)) {
140
+ return *this;
141
+ } else {
142
+ return __dispatch_contiguous(memory_format);
143
+ }
144
+ }
145
+
146
+ /// Should be used if *this can reasonably be expected to be contiguous and
147
+ /// performance is important.
148
+ /// Compared to contiguous, it saves a reference count
149
+ /// increment/decrement if *this is already contiguous, at the cost
150
+ /// in all cases of an extra pointer of stack usage, an extra branch
151
+ /// to access, and an extra branch at destruction time.
152
+ c10::MaybeOwned<TensorBase> expect_contiguous(
153
+ MemoryFormat memory_format=MemoryFormat::Contiguous) const &;
154
+
155
+ // Use .contiguous() instead. Trying to borrow from a prvalue
156
+ // will only lead to trouble and dangling references.
157
+ c10::MaybeOwned<TensorBase> expect_contiguous(
158
+ MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
159
+
160
+ const TensorBase& fill_(const c10::Scalar& scalar) const;
161
+ const TensorBase& zero_() const;
162
+
163
+ TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, std::optional<at::MemoryFormat> memory_format=std::nullopt) const;
164
+
165
+ bool is_complex() const {
166
+ return at::isComplexType(this->scalar_type());
167
+ }
168
+
169
+ bool is_floating_point() const {
170
+ return at::isFloatingType(this->scalar_type());
171
+ }
172
+
173
+ bool is_signed() const {
174
+ return at::isSignedType(this->scalar_type());
175
+ }
176
+
177
+ c10::SymInt sym_size(int64_t dim) const {
178
+ return impl_->sym_size(dim);
179
+ }
180
+
181
+ c10::SymInt sym_stride(int64_t dim) const {
182
+ const auto sizes = this->sym_strides();
183
+ const auto ndim = static_cast<int64_t>(sizes.size());
184
+ // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
185
+ return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
186
+
187
+ }
188
+
189
+ int64_t size(int64_t dim) const {
190
+ return impl_->size(dim);
191
+ }
192
+
193
+ int64_t stride(int64_t dim) const {
194
+ const auto strides = this->strides();
195
+ const auto ndim = static_cast<int64_t>(strides.size());
196
+ // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
197
+ return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
198
+ }
199
+
200
+ TensorImpl * unsafeGetTensorImpl() const {
201
+ return impl_.get();
202
+ }
203
+ TensorImpl * unsafeReleaseTensorImpl() {
204
+ return impl_.release();
205
+ }
206
+ const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr() const {
207
+ return impl_;
208
+ }
209
+
210
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> unsafeReleaseIntrusivePtr() {
211
+ return std::move(impl_);
212
+ }
213
+
214
+ bool defined() const {
215
+ return impl_;
216
+ }
217
+
218
+ void reset() {
219
+ impl_.reset();
220
+ }
221
+
222
+ #if defined (_MSC_VER)
223
+ TensorBase& operator=(const TensorBase& x) & {
224
+ impl_ = x.impl_;
225
+ return *this;
226
+ };
227
+ TensorBase& operator=(TensorBase&& x) & noexcept {
228
+ impl_ = std::move(x.impl_);
229
+ return *this;
230
+ }
231
+ #else
232
+ TensorBase& operator=(const TensorBase& x) & = default;
233
+ TensorBase& operator=(TensorBase&& x) & noexcept = default;
234
+ #endif
235
+
236
+ // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here
237
+ TensorBase& operator=(const TensorBase&) && = delete;
238
+ TensorBase& operator=(TensorBase&&) && noexcept = delete;
239
+
240
+ bool is_same(const TensorBase& other) const noexcept {
241
+ return impl_ == other.impl_;
242
+ }
243
+ size_t use_count() const noexcept {
244
+ return impl_.use_count();
245
+ }
246
+ size_t weak_use_count() const noexcept {
247
+ return impl_.weak_use_count();
248
+ }
249
+ bool is_uniquely_owned() const noexcept {
250
+ return impl_.is_uniquely_owned();
251
+ }
252
+
253
+ std::string toString() const;
254
+
255
+ IntArrayRef sizes() const {
256
+ return impl_->sizes();
257
+ }
258
+ c10::SymIntArrayRef sym_sizes() const {
259
+ return impl_->sym_sizes();
260
+ }
261
+ c10::SymIntArrayRef sym_strides() const {
262
+ return impl_->sym_strides();
263
+ }
264
+ IntArrayRef strides() const {
265
+ return impl_->strides();
266
+ }
267
+ // See impl::get_opt_names in ATen/NamedTensor.h for docs.
268
+ std::optional<DimnameList> opt_names() const {
269
+ return impl::get_opt_names(unsafeGetTensorImpl());
270
+ }
271
+ // See impl::get_names in ATen/NamedTensor.h for docs.
272
+ DimnameList names() const {
273
+ return impl::get_names(unsafeGetTensorImpl());
274
+ }
275
+ int64_t ndimension() const {
276
+ return dim();
277
+ }
278
+
279
+ bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
280
+ return impl_->is_contiguous(memory_format);
281
+ }
282
+
283
+ // Like is_contiguous, but more dynamic shape-friendly. May return a symbolic representation of
284
+ // contiguity instead of SymTrue SymFalse, when results are data-dependent.
285
+ c10::SymBool sym_is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
286
+ if (impl_->has_symbolic_sizes_strides()) {
287
+ return impl_->sym_is_contiguous(memory_format);
288
+ }
289
+ return impl_->is_contiguous(memory_format);
290
+ }
291
+
292
+ // Like is_contiguous, but more dynamic shape-friendly. Can returns
293
+ // false instead of throwing data-dependent errors for tensors with unbacked
294
+ // sizes or strides.
295
+ bool is_contiguous_or_false(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
296
+ if (impl_->has_symbolic_sizes_strides()) {
297
+ return impl_->sym_is_contiguous(memory_format).guard_or_false(__FILE__, __LINE__);
298
+ }
299
+ return impl_->is_contiguous(memory_format);
300
+ }
301
+
302
+ bool is_non_overlapping_and_dense() const {
303
+ return impl_->is_non_overlapping_and_dense();
304
+ }
305
+
306
+ at::MemoryFormat suggest_memory_format(
307
+ bool channels_last_strides_exact_match = false) const {
308
+ // Setting channels_last_strides_exact_match to true forces function to
309
+ // check 0,1 - sized dimension strides.
310
+ if (layout() == at::kStrided) {
311
+ if (impl_->is_strides_like_channels_last()) {
312
+ if (!channels_last_strides_exact_match ||
313
+ get_channels_last_strides_2d(sizes()) == strides()) {
314
+ return at::MemoryFormat::ChannelsLast;
315
+ }
316
+ }
317
+ else if (impl_->is_strides_like_channels_last_3d()) {
318
+ if (!channels_last_strides_exact_match ||
319
+ get_channels_last_strides_3d(sizes()) == strides()) {
320
+ return at::MemoryFormat::ChannelsLast3d;
321
+ }
322
+ }
323
+ }
324
+ return at::MemoryFormat::Contiguous;
325
+ }
326
+
327
+ // Total bytes consumed by the "view" of elements of the array. Does not
328
+ // include size of metadata. The number reported here does not necessarily
329
+ // correspond to the true physical memory consumed by a tensor; instead,
330
+ // it reports the memory the tensor would take *if* it were contiguous.
331
+ // Defined to be numel() * itemsize()
332
+ size_t nbytes() const {
333
+ TORCH_CHECK(layout () != at::kSparse,
334
+ "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
335
+ "tensors, add the nbytes of the indices and values. If you want the size of the " \
336
+ "equivalent dense tensor, multiply numel() by element_size()");
337
+ return impl_->numel() * impl_->itemsize();
338
+ }
339
+
340
+ c10::SymInt sym_nbytes() const {
341
+ TORCH_CHECK(layout () != at::kSparse,
342
+ "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
343
+ "tensors, add the nbytes of the indices and values. If you want the size of the " \
344
+ "equivalent dense tensor, multiply numel() by element_size()");
345
+ return impl_->sym_numel() * impl_->itemsize();
346
+ }
347
+
348
+ int64_t numel() const {
349
+ return impl_->numel();
350
+ }
351
+
352
+ c10::SymInt sym_numel() const {
353
+ return impl_->sym_numel();
354
+ }
355
+
356
+ c10::SymInt sym_storage_offset() const {
357
+ return impl_->sym_storage_offset();
358
+ }
359
+
360
+ // Length of one array element in bytes. This is the traditional
361
+ // Numpy naming.
362
+ size_t itemsize() const {
363
+ return impl_->itemsize();
364
+ }
365
+
366
+ // Same as itemsize(). This is the PyTorch naming.
367
+ int64_t element_size() const {
368
+ return static_cast<int64_t>(impl_->itemsize());
369
+ }
370
+
371
+ DispatchKeySet key_set() const {
372
+ return impl_->key_set();
373
+ }
374
+ ScalarType scalar_type() const {
375
+ return typeMetaToScalarType(impl_->dtype());
376
+ }
377
+ bool has_storage() const {
378
+ return defined() && impl_->has_storage();
379
+ }
380
+ const Storage& storage() const {
381
+ return impl_->storage();
382
+ }
383
+ bool is_alias_of(const at::TensorBase& other) const{
384
+ return impl_->storage().is_alias_of(other.storage());
385
+ }
386
+
387
+ // Move the storage backend to shm based
388
+ // to enable memory sharing across processes.
389
+ //
390
+ // NB1: the ideal behavior of this API still requires further discussion
391
+ // but for now we are inclined to keep it consistent with existing THP behavior
392
+ // https://github.com/pytorch/pytorch/blob/4dca9bde0552afc67b5b74f4a0696fe6055709c4/torch/storage.py#L196-L212
393
+ // so we don't assert on anything here and rely on caller knowing
394
+ // what it's doing.
395
+ //
396
+ // NB2: this currently provides Linux fd based shm support only
397
+ // to simplify the storage lifetime management logic in ATen
398
+ // and similarly for now we are not adding support for file system based
399
+ // shm support like in THP due to additional GC manager support needed
400
+ // to prevent leaks.
401
+ // As such, calling this from non supported systems (e.g. Windows) would fail.
402
+ void share_memory_() {
403
+ at::share_memory_(*this);
404
+ }
405
+
406
+ inline bool _is_zerotensor() const {
407
+ return impl_->_is_zerotensor();
408
+ }
409
+
410
+ inline void _set_zero(bool zero) const {
411
+ impl_->_set_zero(zero);
412
+ }
413
+
414
+ inline bool is_conj() const {
415
+ return impl_->is_conj();
416
+ }
417
+
418
+ // sets the conjugate bit of a tensor.
419
+ // NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are sure
420
+ // that's what you want. Changing this might lead to incorrect behavior since conjugation is
421
+ // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
422
+ inline void _set_conj(bool conjugate) const {
423
+ impl_->_set_conj(conjugate);
424
+ }
425
+
426
+ inline bool is_neg() const {
427
+ return impl_->is_neg();
428
+ }
429
+
430
+ // sets the negative bit of a tensor.
431
+ // NOTE: Negative bit is supposed to be a read-only field. Only change this, if you are sure
432
+ // that's what you want. Changing this might lead to incorrect behavior since we rely on this
433
+ // bit to determine if a negation needs to be materialized.
434
+ inline void _set_neg(bool negative) const {
435
+ impl_->_set_neg(negative);
436
+ }
437
+
438
+ /// Returns a `Tensor`'s layout.
439
+ Layout layout() const {
440
+ return impl_->layout();
441
+ }
442
+
443
+ /// Returns a `Tensor`'s dtype (`TypeMeta`).
444
+ caffe2::TypeMeta dtype() const {
445
+ return impl_->dtype();
446
+ }
447
+
448
+ /// Returns a `Tensor`'s device.
449
+ inline Device device() const {
450
+ return impl_->device();
451
+ }
452
+
453
+ /// Returns a `Tensor`'s device index.
454
+ DeviceIndex get_device() const {
455
+ // NB: this is not a native function to avoid dispatching overhead.
456
+ return impl_->get_device();
457
+ }
458
+
459
+ /// Returns if a `Tensor` has CPU backend.
460
+ bool is_cpu() const {
461
+ // NB: this is not a native function to avoid dispatching overhead.
462
+ return impl_->is_cpu();
463
+ }
464
+
465
+ /// Returns if a `Tensor` has CUDA backend.
466
+ bool is_cuda() const {
467
+ // NB: this is not a native function to avoid dispatching overhead.
468
+ return impl_->is_cuda();
469
+ }
470
+
471
+ /// Returns if a `Tensor` has IPU backend.
472
+ bool is_ipu() const {
473
+ // NB: this is not a native function to avoid dispatching overhead.
474
+ return impl_->is_ipu();
475
+ }
476
+
477
+ /// Returns if a `Tensor` has XPU backend.
478
+ bool is_xpu() const {
479
+ // NB: this is not a native function to avoid dispatching overhead.
480
+ return impl_->is_xpu();
481
+ }
482
+
483
+ /// Returns if a `Tensor` has XLA backend.
484
+ bool is_xla() const {
485
+ return impl_->is_xla();
486
+ }
487
+
488
+ /// Returns if a `Tensor` has MTIA backend.
489
+ bool is_mtia() const {
490
+ return impl_->is_mtia();
491
+ }
492
+
493
+ /// Returns if a `Tensor` has HPU backend.
494
+ bool is_hpu() const {
495
+ return impl_->is_hpu();
496
+ }
497
+
498
+ /// Returns if a `Tensor` has Lazy backend.
499
+ bool is_lazy() const {
500
+ return impl_->is_lazy();
501
+ }
502
+
503
+ /// Returns if a `Tensor` has HIP backend.
504
+ bool is_hip() const {
505
+ // NB: this is not a native function to avoid dispatching overhead.
506
+ return impl_->is_hip();
507
+ }
508
+
509
+ /// Returns if a `Tensor` has VE backend.
510
+ bool is_ve() const {
511
+ // NB: this is not a native function to avoid dispatching overhead.
512
+ return impl_->is_ve();
513
+ }
514
+
515
+ /// Returns if a `Tensor` has PrivateUse1 backend.
516
+ bool is_privateuseone() const {
517
+ // NB: this is not a native function to avoid dispatching overhead.
518
+ return impl_->is_privateuseone();
519
+ }
520
+
521
+ /// Returns if a `Tensor` has sparse backend.
522
+ bool is_sparse() const {
523
+ // NB: this is not a native function to avoid dispatching overhead.
524
+ return impl_->is_sparse();
525
+ }
526
+
527
+ /// Returns is a `Tensor` has a sparse CSR backend.
528
+ bool is_sparse_csr() const {
529
+ // NB: this is not a native function to avoid dispatching overhead.
530
+ return impl_->is_sparse_csr();
531
+ }
532
+
533
+ /// Returns if a `Tensor` is mkldnn tensor.
534
+ bool is_mkldnn() const {
535
+ // NB: this is not a native function to avoid dispatching overhead.
536
+ return impl_->is_mkldnn();
537
+ }
538
+
539
+ /// Returns if a `Tensor` is mps tensor.
540
+ bool is_mps() const {
541
+ // NB: this is not a native function to avoid dispatching overhead.
542
+ return impl_->is_mps();
543
+ }
544
+
545
+ /// Returns if a `Tensor` is maia tensor.
546
+ bool is_maia() const {
547
+ // NB: this is not a native function to avoid dispatching overhead.
548
+ return impl_->is_maia();
549
+ }
550
+
551
+ /// Returns if a `Tensor` is vulkan tensor.
552
+ bool is_vulkan() const {
553
+ // NB: this is not a native function to avoid dispatching overhead.
554
+ return impl_->is_vulkan();
555
+ }
556
+
557
+ /// Returns if a `Tensor` is metal tensor.
558
+ bool is_metal() const {
559
+ // NB: this is not a native function to avoid dispatching overhead.
560
+ return impl_->is_metal();
561
+ }
562
+
563
+ /// Returns if a `Tensor` has quantized backend.
564
+ bool is_quantized() const {
565
+ // NB: this is not a native function to avoid dispatching overhead.
566
+ return impl_->is_quantized();
567
+ }
568
+
569
+ /// Returns if a `Tensor` is a meta tensor. Meta tensors can
570
+ /// also have other designations.
571
+ bool is_meta() const {
572
+ return impl_->is_meta();
573
+ }
574
+
575
+ /// Returns if a `Tensor` is an inference tensor.
576
+ bool is_inference() const {
577
+ return impl_->is_inference();
578
+ }
579
+
580
+ // Returns if a `Tensor` is a NestedTensor.
581
+ bool is_nested() const {
582
+ return impl_->is_nested();
583
+ }
584
+
585
+ /// If a tensor is a quantized tensor, returns its quantizer
586
+ /// TODO: it's not in native_functions.yaml yet as it's not exposed to python
587
+ QuantizerPtr quantizer() const;
588
+
589
+ /// Returns if a `Tensor` has any dimension names
590
+ bool has_names() const {
591
+ // If a user is using unnamed tensors, then we can short-circuit right here.
592
+ // Otherwise, impl::has_names attempts to retrieve names.
593
+ if (!impl_->has_named_tensor_meta()) {
594
+ return false;
595
+ }
596
+ return impl::has_names(unsafeGetTensorImpl());
597
+ }
598
+
599
+ /// Returns a `Tensor`'s dimension names data structure
600
+ const NamedTensorMeta* get_named_tensor_meta() const {
601
+ return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
602
+ }
603
+
604
+ NamedTensorMeta* get_named_tensor_meta() {
605
+ return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
606
+ }
607
+
608
+ /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
609
+ /// TensorOptions.h.
610
+ TensorOptions options() const {
611
+ return TensorOptions().dtype(dtype())
612
+ .device(device())
613
+ .layout(layout());
614
+ }
615
+
616
+ const void* const_data_ptr() const {
617
+ return this->unsafeGetTensorImpl()->data();
618
+ }
619
+
620
+ void* mutable_data_ptr() const {
621
+ return this->unsafeGetTensorImpl()->mutable_data();
622
+ }
623
+
624
+ // TODO(#97856) Make this return a const pointer. This currently
625
+ // returns a non-const pointer because of the large
626
+ // number of clients that we still want to audit before
627
+ // migrating to mutable_data_ptr().
628
+ void* data_ptr() const {
629
+ return mutable_data_ptr();
630
+ }
631
+
632
+ template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
633
+ const T* const_data_ptr() const;
634
+
635
+ template <typename T, std::enable_if_t<std::is_const_v<T>, int> = 0>
636
+ const std::remove_const_t<T>* const_data_ptr() const;
637
+
638
+ template <typename T>
639
+ T* mutable_data_ptr() const;
640
+
641
+ // Legacy interface during the migration to indicate that a callsite
642
+ // has not been audited for mutability.
643
+ //
644
+ // Do not add new uses of this, use const_data_ptr() if possible,
645
+ // mutable_data_ptr() otherwise.
646
+ //
647
+ // TODO(#97856) Make this return a const pointer. This is currently
648
+ // const because of the vast number of clients that
649
+ // rely on this.
650
+ template <typename T>
651
+ T* data_ptr() const;
652
+
653
+ // Purposely not defined here to avoid inlining
654
+ void print() const;
655
+
656
+ // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and
657
+ // dimension.
658
+ template<typename T, size_t N>
659
+ TensorAccessor<T,N> accessor() const& {
660
+ static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
661
+ TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
662
+ T* ptr = nullptr;
663
+ if constexpr (std::is_const_v<T>) {
664
+ ptr = const_data_ptr<T>();
665
+ } else {
666
+ ptr = mutable_data_ptr<T>();
667
+ }
668
+ return TensorAccessor<T,N>(ptr,sizes().data(),strides().data());
669
+ }
670
+ template<typename T, size_t N>
671
+ TensorAccessor<T,N> accessor() && = delete;
672
+
673
+ // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
674
+ // dimension. You can optionally specify RestrictPtrTraits as a template parameter to
675
+ // cast the data pointer to a __restrict__ pointer.
676
+ // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
677
+ // as an argument.
678
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
679
+ GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& {
680
+ static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
681
+ TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
682
+ T* ptr = nullptr;
683
+ if constexpr (std::is_const_v<T>) {
684
+ ptr = const_data_ptr<T>();
685
+ } else {
686
+ ptr = mutable_data_ptr<T>();
687
+ }
688
+ return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(ptr),sizes().data(),strides().data());
689
+ }
690
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
691
+ GenericPackedTensorAccessor<T,N> generic_packed_accessor() && = delete;
692
+
693
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
694
+ PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& {
695
+ TORCH_CHECK(
696
+ impl_->numel() <=
697
+ static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
698
+ "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
699
+ return generic_packed_accessor<T,N,PtrTraits,int32_t>();
700
+ }
701
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
702
+ PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete;
703
+
704
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
705
+ PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& {
706
+ return generic_packed_accessor<T,N,PtrTraits,int64_t>();
707
+ }
708
+ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
709
+ PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete;
710
+
711
+ // ~~~~~ Autograd API ~~~~~
712
+
713
+ /// \fn bool is_leaf() const;
714
+ ///
715
+ /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention.
716
+ ///
717
+ /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were
718
+ /// created by the user. This means that they are not the result of an operation and so
719
+ /// `grad_fn()` is `nullptr`.
720
+ ///
721
+ /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`.
722
+ /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`.
723
+ ///
724
+ /// Example:
725
+ /// @code
726
+ /// auto a = torch::rand(10, torch::requires_grad());
727
+ /// std::cout << a.is_leaf() << std::endl; // prints `true`
728
+ ///
729
+ /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA);
730
+ /// std::cout << b.is_leaf() << std::endl; // prints `false`
731
+ /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor
732
+ ///
733
+ /// auto c = torch::rand(10, torch::requires_grad()) + 2;
734
+ /// std::cout << c.is_leaf() << std::endl; // prints `false`
735
+ /// // c was created by the addition operation
736
+ ///
737
+ /// auto d = torch::rand(10).cuda();
738
+ /// std::cout << d.is_leaf() << std::endl; // prints `true`
739
+ /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
740
+ ///
741
+ /// auto e = torch::rand(10).cuda().requires_grad_();
742
+ /// std::cout << e.is_leaf() << std::endl; // prints `true`
743
+ /// // e requires gradients and has no operations creating it
744
+ ///
745
+ /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true));
746
+ /// std::cout << f.is_leaf() << std::endl; // prints `true`
747
+ /// // f requires grad, has no operation creating it
748
+ /// @endcode
749
+
750
+ /// \fn void backward(const Tensor & gradient={}, std::optional<bool> retain_graph=std::nullopt, bool create_graph=false, std::optional<TensorList> inputs=std::nullopt) const;
751
+ ///
752
+ /// Computes the gradient of current tensor with respect to graph leaves.
753
+ ///
754
+ /// The graph is differentiated using the chain rule. If the tensor is
755
+ /// non-scalar (i.e. its data has more than one element) and requires
756
+ /// gradient, the function additionally requires specifying ``gradient``.
757
+ /// It should be a tensor of matching type and location, that contains
758
+ /// the gradient of the differentiated function w.r.t. this Tensor.
759
+ ///
760
+ /// This function accumulates gradients in the leaves - you might need to
761
+ /// zero them before calling it.
762
+ ///
763
+ /// \param gradient Gradient w.r.t. the
764
+ /// tensor. If it is a tensor, it will be automatically converted
765
+ /// to a Tensor that does not require grad unless ``create_graph`` is True.
766
+ /// None values can be specified for scalar Tensors or ones that
767
+ /// don't require grad. If a None value would be acceptable then
768
+ /// this argument is optional.
769
+ /// \param retain_graph If ``false``, the graph used to compute
770
+ /// the grads will be freed. Note that in nearly all cases setting
771
+ /// this option to True is not needed and often can be worked around
772
+ /// in a much more efficient way. Defaults to the value of
773
+ /// ``create_graph``.
774
+ /// \param create_graph If ``true``, graph of the derivative will
775
+ /// be constructed, allowing to compute higher order derivative
776
+ /// products. Defaults to ``false``.
777
+ /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
778
+ /// ``at::Tensor::grad``. All other Tensors will be ignored. If not
779
+ /// provided, the gradient is accumulated into all the leaf Tensors
780
+ /// that were used to compute the current tensor.
781
+ /// When inputs are provided and a given input is not a leaf,
782
+ /// the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients).
783
+ /// It is an implementation detail on which the user should not rely.
784
+ /// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
785
+
786
+ /// \fn Tensor detach() const;
787
+ ///
788
+ /// Returns a new Tensor, detached from the current graph.
789
+ /// The result will never require gradient.
790
+
791
+ /// \fn Tensor & detach_() const;
792
+ ///
793
+ /// Detaches the Tensor from the graph that created it, making it a leaf.
794
+ /// Views cannot be detached in-place.
795
+
796
+ /// \fn void retain_grad() const;
797
+ ///
798
+ /// Enables this Tensor to have their :attr:`grad` populated during
799
+ /// :func:`backward`. This is a no-op for leaf tensors.
800
+
801
+ /// \fn bool retains_grad() const;
802
+ ///
803
+ /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
804
+ /// populated during :func:`backward`, ``false`` otherwise.
805
+
806
+ const TensorBase& set_requires_grad(bool requires_grad) const {
807
+ impl_->set_requires_grad(requires_grad);
808
+ return *this;
809
+ }
810
+ bool requires_grad() const {
811
+ return impl_->requires_grad();
812
+ }
813
+
814
+ // The Forward AD API functions below are low level and are not to be used by end
815
+ // users who should use the API provided in torch/csrc/autograd.h
816
+
817
+ /// This function returns the forward gradient for this Tensor at the given level.
818
+ const Tensor& _fw_grad(uint64_t level) const {
819
+ return impl_->_fw_grad(level, *this);
820
+ }
821
+
822
+ /// This function can be used to set the value of the forward grad.
823
+ /// Note that the given new_grad might not be used directly if it has different
824
+ /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
825
+ /// new_grad content will be copied into a new Tensor
826
+ void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
827
+ impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
828
+ }
829
+
830
+ /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
831
+ /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
832
+ /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
833
+ ///
834
+ /// One notable difference with the legacy `.data()` function is that changes to the
835
+ /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
836
+ /// will not update the original `Variable`, due to the fact that this function
837
+ /// shallow-copies the `Variable`'s underlying TensorImpl.
838
+ at::TensorBase tensor_data() const;
839
+
840
+ /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
841
+ /// in Python, which create a new `Variable` that shares the same storage and
842
+ /// tensor metadata with the original `Variable`, but with a completely new
843
+ /// autograd history.
844
+ ///
845
+ /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
846
+ /// storage / storage_offset) of a variable created from `var.variable_data()`, those
847
+ /// changes will not update the original variable `var`. In `.variable_data()`, we set
848
+ /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
849
+ /// in order to prevent users from changing metadata of `var.variable_data()`
850
+ /// and expecting the original variable `var` to also be updated.
851
+ at::TensorBase variable_data() const;
852
+
853
+ // Gradient Node and Edges
854
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
855
+
856
+ /// Gets the gradient function of the `Variable`. If this is a leaf variable,
857
+ /// the pointer returned will be null.
858
+ ///
859
+ /// For View Variables:
860
+ /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
861
+ /// re-create the grad_fn to express the up-to-date view relationship between
862
+ /// this and the base Variable.
863
+ const std::shared_ptr<torch::autograd::Node>& grad_fn() const;
864
+
865
+ // Hooks
866
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
867
+
868
+ template <typename T>
869
+ using hook_return_void_t = std::enable_if_t<std::is_void_v<typename std::invoke_result_t<T&, TensorBase>>, unsigned>;
870
+ template <typename T>
871
+ using hook_return_var_t = std::enable_if_t<std::is_same_v<typename std::invoke_result_t<T&, TensorBase>, TensorBase>, unsigned>;
872
+
873
+ /// Registers a backward hook.
874
+ ///
875
+ /// The hook will be called every time a gradient with respect to the Tensor is computed.
876
+ /// The hook should have one of the following signature:
877
+ /// ```
878
+ /// hook(TensorBase grad) -> TensorBase
879
+ /// ```
880
+ /// ```
881
+ /// hook(TensorBase grad) -> void
882
+ /// ```
883
+ /// The hook should not modify its argument, but it can optionally return a new gradient
884
+ /// which will be used in place of `grad`.
885
+ ///
886
+ /// This function returns the index of the hook in the list which can be used to remove hook.
887
+ ///
888
+ /// Example:
889
+ /// @code
890
+ /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
891
+ /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
892
+ /// v.backward(torch::tensor({1., 2., 3.}));
893
+ /// // This prints:
894
+ /// // ```
895
+ /// // 2
896
+ /// // 4
897
+ /// // 6
898
+ /// // [ CPUFloatType{3} ]
899
+ /// // ```
900
+ /// std::cout << v.grad() << std::endl;
901
+ /// v.remove_hook(h); // removes the hook
902
+ /// @endcode
903
+ template <typename T>
904
+ hook_return_void_t<T> register_hook(T&& hook) const;
905
+ template <typename T>
906
+ hook_return_var_t<T> register_hook(T&& hook) const;
907
+
908
+ protected:
909
+ unsigned _register_hook(std::function<TensorBase(const TensorBase&)> hook) const;
910
+
911
+ public:
912
+
913
+ /// Remove hook at given position
914
+ void remove_hook(unsigned pos) const;
915
+
916
+ // Variable methods
917
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
918
+
919
+ bool is_leaf() const;
920
+
921
+ int64_t output_nr() const;
922
+
923
+ void set_data(const TensorBase & new_data) const;
924
+
925
+ TensorBase data() const;
926
+
927
+ int64_t _version() const;
928
+
929
+ void retain_grad() const;
930
+
931
+ bool retains_grad() const;
932
+
933
+ const TensorBase& requires_grad_(bool _requires_grad=true) const;
934
+
935
+ std::optional<ScalarType> grad_dtype() const;
936
+
937
+ void set_grad_dtype(const std::optional<ScalarType>& grad_dtype) const;
938
+
939
+ // View Variables
940
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
941
+
942
+ /// Returns true if this `Variable` is a view of another `Variable`.
943
+ bool is_view() const;
944
+
945
+ /// Returns the `Variable` that this `Variable` is a view of. If this
946
+ /// `Variable` is not a view, throw a `std::runtime_error`.
947
+ const TensorBase& _base() const;
948
+
949
+ // Miscellaneous
950
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
951
+
952
+ const std::string& name() const;
953
+
954
+ protected:
955
+ void enforce_invariants();
956
+ c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
957
+
958
+ private:
959
+ TensorBase __dispatch_contiguous(c10::MemoryFormat /*memory_format*/) const;
960
+ };
961
+
962
+ inline DeviceIndex get_device(const TensorBase& self) {
963
+ return self.get_device();
964
+ }
965
+
966
+ template <typename T>
967
+ auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t<T> {
968
+ // Return the grad argument in case of a hook with void return type to have an
969
+ // std::function with Tensor return type
970
+ static_assert(std::is_same_v<decltype(hook(TensorBase())), void>,
971
+ "Expected hook to return void");
972
+ return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad) {
973
+ fn(grad);
974
+ return TensorBase();
975
+ });
976
+ }
977
+
978
+ template <typename T>
979
+ auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> {
980
+ return _register_hook(std::forward<T>(hook));
981
+ }
982
+
983
+ namespace detail {
984
+ // Helper creator for Tensor class which doesn't requires the users to pass
985
+ // in an intrusive_ptr instead it just converts the argument passed to
986
+ // requested intrusive_ptr type.
987
+ template <typename T, typename... Args>
988
+ TensorBase make_tensor_base(Args&&... args) {
989
+ return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...));
990
+ }
991
+
992
+ } // namespace detail
993
+
994
+ inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
995
+ return legacyExtractDispatchKey(t.key_set());
996
+ }
997
+
998
+ } // namespace at
999
+
1000
+ namespace c10 {
1001
+ template <>
1002
+ struct MaybeOwnedTraits<at::TensorBase> {
1003
+ using owned_type = at::TensorBase;
1004
+ using borrow_type = at::TensorBase;
1005
+
1006
+ static borrow_type createBorrow(const owned_type& from) {
1007
+ // NOTE: this can be implemented without the special
1008
+ // unsafe_borrow_t Tensor constructor as
1009
+ //
1010
+ // return borrow_type(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl()));
1011
+ //
1012
+ // but that hurts inlining due to the nullptr check in the
1013
+ // Tensor(c10::intrusive_ptr<...>) constructor. We already know
1014
+ // that from.impl_ isn't null because from is a valid Tensor, so
1015
+ // we needn't do the check again. (using __builtin_assume can
1016
+ // avoid this, but wouldn't be portable to MSVC.)
1017
+ return borrow_type(borrow_type::unsafe_borrow_t{}, from);
1018
+ }
1019
+
1020
+ static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
1021
+ lhs.unsafeReleaseTensorImpl();
1022
+ // See above note: this can be implemented with public API
1023
+ // similarly to createBorrow(), but that would hurt inlining.
1024
+ lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
1025
+ }
1026
+
1027
+ static void destroyBorrow(borrow_type& toDestroy) {
1028
+ toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
1029
+ }
1030
+
1031
+ static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
1032
+ return borrow;
1033
+ }
1034
+
1035
+ static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
1036
+ return &borrow;
1037
+ }
1038
+
1039
+ static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
1040
+ return true;
1041
+ }
1042
+ };
1043
+
1044
+ template <>
1045
+ struct ExclusivelyOwnedTraits<at::TensorBase> : public c10::ExclusivelyOwnedTensorTraits<at::TensorBase> {};
1046
+ } // namespace c10
1047
+
1048
+ namespace at {
1049
+
1050
+ inline c10::MaybeOwned<TensorBase> borrow_from_optional_tensor(
1051
+ const std::optional<TensorBase>& opt) {
1052
+ return opt.has_value()
1053
+ ? c10::MaybeOwned<TensorBase>::borrowed(*opt)
1054
+ : c10::MaybeOwned<TensorBase>::owned(std::in_place);
1055
+ }
1056
+
1057
+ inline c10::MaybeOwned<TensorBase> TensorBase::expect_contiguous(MemoryFormat memory_format) const & {
1058
+ if (is_contiguous(memory_format)) {
1059
+ return c10::MaybeOwned<TensorBase>::borrowed(*this);
1060
+ } else {
1061
+ return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format));
1062
+ }
1063
+ }
1064
+
1065
+ namespace symint {
1066
+
1067
+ template <typename T>
1068
+ using enable_if_symint = std::enable_if_t<std::is_same_v<T, c10::SymInt>>;
1069
+ template <typename T>
1070
+ using enable_if_int = std::enable_if_t<std::is_same_v<T, int64_t>>;
1071
+
1072
+ template <typename T, typename = enable_if_symint<T>>
1073
+ c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); }
1074
+ template <typename T, typename = enable_if_int<T>>
1075
+ IntArrayRef sizes(const TensorBase& t) { return t.sizes(); }
1076
+
1077
+ template <typename T, typename = enable_if_symint<T>>
1078
+ c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); }
1079
+ template <typename T, typename = enable_if_int<T>>
1080
+ int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); }
1081
+
1082
+ template <typename T, typename = enable_if_symint<T>>
1083
+ c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); }
1084
+ template <typename T, typename = enable_if_int<T>>
1085
+ IntArrayRef strides(const TensorBase& t) { return t.strides(); }
1086
+
1087
+ template <typename T, typename = enable_if_symint<T>>
1088
+ c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); }
1089
+ template <typename T, typename = enable_if_int<T>>
1090
+ int64_t numel(const TensorBase& t) { return t.numel(); }
1091
+
1092
+ } // namespace symint
1093
+
1094
+ } // namespace at
1095
+
1096
+ #else
1097
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
1098
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBody.h ADDED
The diff for this file is too large to render. See raw diff
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TorchDispatchUtils.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/dispatch/Dispatcher.h>
5
+ #include <c10/core/impl/TorchDispatchModeTLS.h>
6
+ #include <c10/util/ArrayRef.h>
7
+ #include <torch/library.h>
8
+ #include <optional>
9
+
10
+ namespace at::impl {
11
+
12
+ TORCH_API bool tensor_has_dispatch(const at::Tensor& t);
13
+ TORCH_API bool tensorlist_has_dispatch(at::ITensorListRef li);
14
+ TORCH_API bool tensorlist_has_dispatch(
15
+ const c10::List<std::optional<at::Tensor>>& li);
16
+ using c10::impl::dispatch_mode_enabled;
17
+
18
+ } // namespace at::impl
19
+
20
+ #else
21
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
22
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TransformationHelper.h ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #include <ATen/NumericUtils.h>
3
+ #include <c10/macros/Macros.h>
4
+ #include <c10/util/Half.h>
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/MathConstants.h>
7
+ #include <cmath>
8
+ #include <cstdint>
9
+ #include <cassert>
10
+ #include <limits>
11
+ #include <type_traits>
12
+
13
+ namespace at {
14
+
15
+ // Using DistAccumType in accumulate types for distributions.
16
+ // Note: Ideally we'd be using ATen/AccumulateType.h but looks
17
+ // like the there is some inconsistency in how accumulate types
18
+ // are mapped currently, e.g. for the cpu side, float is mapped
19
+ // to double.
20
+ template <typename T>
21
+ struct DistAccumType { };
22
+
23
+ #if defined(__CUDACC__) || defined(__HIPCC__)
24
+ template <> struct DistAccumType<half> { using type = float; };
25
+ #endif
26
+ template <> struct DistAccumType<BFloat16> { using type = float; };
27
+ template <> struct DistAccumType<Half> { using type = float; };
28
+ template <> struct DistAccumType<float> { using type = float; };
29
+ template <> struct DistAccumType<double> { using type = double; };
30
+
31
+ template <typename T>
32
+ using dist_acctype = typename DistAccumType<T>::type;
33
+
34
+ namespace transformation {
35
+
36
+ /**
37
+ * A transformation function for `torch.Tensor.random_()`, when both `from` and `to` are specified.
38
+ * `range` is `to - from`
39
+ * `base` is `from`
40
+ */
41
+ template <typename T, typename V>
42
+ C10_HOST_DEVICE inline T uniform_int_from_to(V val, uint64_t range, int64_t base) {
43
+ return static_cast<T>(static_cast<int64_t>((val % range) + base));
44
+ }
45
+
46
+ /**
47
+ * A transformation function for `torch.Tensor.random_()`, when `from=min_value(int64_t)` and to=None
48
+ */
49
+ template <typename T, typename V>
50
+ C10_HOST_DEVICE inline T uniform_int_full_range(V val) {
51
+ return static_cast<T>(static_cast<int64_t>(val));
52
+ }
53
+
54
+ /**
55
+ * A transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`.
56
+ * In order to prevent compiler warnings reported in GitHub issue 46391, T can't be float or double
57
+ * in this overloaded version
58
+ */
59
+ template <typename T, typename V>
60
+ C10_HOST_DEVICE inline std::enable_if_t<!std::is_floating_point_v<T>, T>uniform_int(V val) {
61
+ if constexpr (std::is_same_v<T, bool>) {
62
+ return static_cast<bool>(val & 1);
63
+ } else if constexpr (std::is_same_v<T, int64_t>) {
64
+ return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
65
+ } else if constexpr (std::is_same_v<T, at::Half> || std::is_same_v<T, at::BFloat16>) {
66
+ return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
67
+ } else if constexpr (std::is_integral_v<T>) {
68
+ return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
69
+ } else {
70
+ assert(false);
71
+ return 0;
72
+ }
73
+ }
74
+
75
+ /**
76
+ * An overloaded transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`,
77
+ * added to fix compiler warnings reported in GitHub issue 46391. T is either float or double in this version.
78
+ */
79
+ template<typename T, typename V>
80
+ C10_HOST_DEVICE inline std::enable_if_t<std::is_floating_point_v<T>, T>uniform_int(V val) {
81
+ return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
82
+ }
83
+
84
+ template <typename T, typename V>
85
+ C10_HOST_DEVICE inline dist_acctype<T> uniform_real(V val, T from, T to) {
86
+ constexpr auto MASK = static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);
87
+ constexpr auto DIVISOR = static_cast<dist_acctype<T>>(1) / (static_cast<uint64_t>(1) << std::numeric_limits<T>::digits);
88
+ dist_acctype<T> x = (val & MASK) * DIVISOR;
89
+ return (x * (to - from) + from);
90
+ }
91
+
92
+ /**
93
+ * Transforms normally distributed `val` with mean 0.0 and standard deviation 1.0 to
94
+ * normally distributed with `mean` and standard deviation `std`.
95
+ */
96
+ template <typename T>
97
+ C10_HOST_DEVICE inline T normal(T val, T mean, T std) {
98
+ return val * std + mean;
99
+ }
100
+
101
+ /**
102
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
103
+ * Cauchy distribution with location parameter `median` and scale parameter `sigma`.
104
+ */
105
+ template <typename T>
106
+ C10_HOST_DEVICE inline T cauchy(T val, T median, T sigma) {
107
+ // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
108
+ // __tanf overflows and returns `inf/-inf` when (val > 1 - eps) or (val < 0 + eps),
109
+ // thus we clip those values.
110
+ constexpr T eps = std::numeric_limits<T>::epsilon();
111
+ constexpr T one_minus_eps = 1 - eps;
112
+ constexpr T zero_plus_eps = 0 + eps;
113
+ val = (val > one_minus_eps ? one_minus_eps : val);
114
+ val = (val < zero_plus_eps ? zero_plus_eps : val);
115
+ return median + sigma * at::tan(c10::pi<T> * (val - static_cast<T>(0.5)));
116
+ }
117
+
118
+ template <>
119
+ C10_HOST_DEVICE inline double cauchy(double val, double median, double sigma) {
120
+ // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
121
+ return median + sigma * at::tan(c10::pi<double> * (val - 0.5));
122
+ }
123
+
124
+ /**
125
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
126
+ * exponentially distributed with `lambda` parameter of the distribution.
127
+ */
128
+ template <typename T>
129
+ C10_HOST_DEVICE inline T exponential(T val, T lambda) {
130
+ // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
131
+ // Different implementations for CUDA and CPU to preserve original logic
132
+ // TODO: must be investigated and unified!!!
133
+ // https://github.com/pytorch/pytorch/issues/38662
134
+ #if defined(__CUDACC__) || defined(__HIPCC__)
135
+ // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
136
+ // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
137
+ // we need log to be not 0, and not underflow when converted to half
138
+ // fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args
139
+ auto log = val >= static_cast<T>(1.) - std::numeric_limits<T>::epsilon() / 2
140
+ ? -std::numeric_limits<T>::epsilon() / 2
141
+ : at::log(val);
142
+ return static_cast<T>(-1.0) / lambda * log;
143
+ #else
144
+ return static_cast<T>(-1.0) / lambda * at::log1p(-val);
145
+ #endif
146
+ }
147
+
148
+ /**
149
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
150
+ * geometrically distributed with success probability `p`.
151
+ */
152
+ template <typename T>
153
+ C10_HOST_DEVICE inline T geometric(T val, T p) {
154
+ // https://en.wikipedia.org/wiki/Geometric_distribution#Related_distributions
155
+ return static_cast<T>(::ceil(at::log(val) / at::log1p(-p)));
156
+ }
157
+
158
+ /**
159
+ * Transforms normally distributed `val` to log-normally distributed.
160
+ */
161
+ template <typename T>
162
+ C10_HOST_DEVICE inline T log_normal(T val) {
163
+ // https://en.wikipedia.org/wiki/Log-normal_distribution#Mode,_median,_quantiles
164
+ return at::exp(val);
165
+ }
166
+
167
+ /**
168
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
169
+ * bernoulli distributed with success probability `p`.
170
+ */
171
+ template <typename T>
172
+ C10_HOST_DEVICE inline T bernoulli(T val, T p) {
173
+ return val < p;
174
+ }
175
+
176
+ }} // namespace at::transformation
177
+
178
+ #else
179
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
180
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #include <c10/core/UndefinedTensorImpl.h>
3
+
4
+ #else
5
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
6
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/UnsafeFromTH.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/core/Tensor.h>
4
+
5
+ namespace at {
6
+
7
+ inline Tensor unsafeTensorFromTH(void * th_pointer, bool retain) {
8
+ auto tensor_impl = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(static_cast<TensorImpl*>(th_pointer));
9
+ if (retain && tensor_impl.get() != UndefinedTensorImpl::singleton()) {
10
+ c10::raw::intrusive_ptr::incref(tensor_impl.get());
11
+ }
12
+ return Tensor(std::move(tensor_impl));
13
+ }
14
+
15
+ inline Storage unsafeStorageFromTH(void * th_pointer, bool retain) {
16
+ if (retain && th_pointer) {
17
+ c10::raw::intrusive_ptr::incref(static_cast<StorageImpl*>(th_pointer));
18
+ }
19
+ return Storage(c10::intrusive_ptr<StorageImpl>::reclaim(static_cast<StorageImpl*>(th_pointer)));
20
+ }
21
+
22
+ }
23
+
24
+ #else
25
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
26
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/VariableHooksInterface.h ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/Tensor.h>
5
+ #include <c10/macros/Export.h>
6
+
7
+ // A little explanation about why this file exists at all. We have
8
+ // a few methods on Tensor class which require access to reified access to
9
+ // AutogradMeta. In open source, this isn't a big deal: we just access
10
+ // torch/csrc/autograd/variable.h from aten/src/ATen/core/Tensor.cpp and
11
+ // we can put the definitions inline. This is because everything gets balled
12
+ // into a single dynamic library in the end.
13
+ //
14
+ // However, inside our Facebook internal version of our build system, we
15
+ // have a split between aten and torch/csrc. So we cannot simply just
16
+ // cross this boundary. "Now wait," you might say, "Why don't we just
17
+ // merge the libraries inside Facebook". Well, the problem is that there
18
+ // are some downstream applications which are at binary size limit, and
19
+ // incorporating all of the extra code from libtorch would push them
20
+ // over (admarket/adreview/service:adreviewservice, see also
21
+ // https://github.com/pytorch/pytorch/pull/29299) So if you want to do that,
22
+ // we have to fix all of the services like this.
23
+ //
24
+ // I didn't want to block eliminating Tensor-Variable on this work, so I
25
+ // had to introduce another dynamic dispatch to get to the variable
26
+ // implementations (which live in torch/csrc/autograd/variable.cpp, FYI).
27
+ //
28
+ // I also considered using our existing dynamic dispatch mechanism, c10
29
+ // dispatcher, to do this. However, (1) some of the functions on Tensor
30
+ // have weird signatures that are not supported by autograd, and (2)
31
+ // see this bug https://github.com/pytorch/pytorch/issues/30102
32
+
33
+ namespace torch::autograd {
34
+
35
+ struct Node;
36
+
37
+ } // namespace torch::autograd
38
+
39
+ namespace at::impl {
40
+
41
+ struct TORCH_API VariableHooksInterface {
42
+ virtual ~VariableHooksInterface() = default;
43
+ virtual TensorBase tensor_data(const TensorBase&) const = 0;
44
+ virtual TensorBase variable_data(const TensorBase&) const = 0;
45
+ virtual const std::shared_ptr<torch::autograd::Node>& grad_fn(
46
+ const TensorBase&) const = 0;
47
+ virtual unsigned _register_hook(
48
+ const TensorBase&,
49
+ std::function<TensorBase(const TensorBase&)> hook) const = 0;
50
+ virtual void remove_hook(const TensorBase&, unsigned pos) const = 0;
51
+ virtual bool is_view(const TensorBase&) const = 0;
52
+ virtual const TensorBase& base(const TensorBase&) const = 0;
53
+ virtual const std::string& name(const TensorBase&) const = 0;
54
+ virtual bool is_leaf(const TensorBase&) const = 0;
55
+ virtual int64_t output_nr(const TensorBase&) const = 0;
56
+ virtual void set_data(const TensorBase&, const TensorBase&) const = 0;
57
+ virtual TensorBase data(const TensorBase&) const = 0;
58
+ virtual int64_t _version(const TensorBase&) const = 0;
59
+ virtual void retain_grad(const TensorBase&) const = 0;
60
+ virtual bool retains_grad(const TensorBase&) const = 0;
61
+ virtual void _backward(
62
+ const Tensor&,
63
+ TensorList,
64
+ const std::optional<Tensor>&,
65
+ std::optional<bool>,
66
+ bool) const = 0;
67
+ virtual void requires_grad_(const TensorBase&, bool) const = 0;
68
+ virtual void basic_autograd_not_implemented_fallback(
69
+ const c10::OperatorHandle& op,
70
+ c10::DispatchKeySet dispatch_keys,
71
+ torch::jit::Stack* stack) const = 0;
72
+ virtual std::optional<c10::ScalarType> grad_dtype(const TensorBase&) const = 0;
73
+ virtual void set_grad_dtype(const TensorBase&, const std::optional<c10::ScalarType>&) const = 0;
74
+ };
75
+
76
+ TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);
77
+ TORCH_API VariableHooksInterface* GetVariableHooks();
78
+ TORCH_API bool HasVariableHooks();
79
+
80
+ struct TORCH_API VariableHooksRegisterer {
81
+ explicit VariableHooksRegisterer(VariableHooksInterface* hooks) {
82
+ SetVariableHooks(hooks);
83
+ }
84
+ };
85
+
86
+ } // namespace at::impl
87
+
88
+ #else
89
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
90
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Variadic.h ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <utility>
5
+
6
+ #include <c10/util/ArrayRef.h>
7
+ #include <ATen/core/List.h>
8
+
9
+ namespace at {
10
+
11
+ // This class allows you to write variadic functions which
12
+ // call a (possibly overloaded) function on each argument,
13
+ // in order. This is most commonly used in autogenerated code,
14
+ // where it is convenient to have a function that can uniformly
15
+ // take arguments of different types. If your arguments
16
+ // are homogeneous consider using a std::initializer_list instead.
17
+ //
18
+ // For examples of this in use, see torch/csrc/utils/variadic.h
19
+ template <typename F>
20
+ struct IterArgs {
21
+ template <typename... Args>
22
+ inline F& apply() {
23
+ return self();
24
+ }
25
+
26
+ // NB: Use perfect forwarding here, otherwise we'll make value
27
+ // copies of all arguments!
28
+ template <typename T, typename... Args>
29
+ inline F& apply(T&& arg, Args&&... args) {
30
+ self()(std::forward<T>(arg));
31
+ if (self().short_circuit()) {
32
+ return self();
33
+ } else {
34
+ return apply(std::forward<Args>(args)...);
35
+ }
36
+ }
37
+
38
+ // Here are some handy overloads which provide sensible
39
+ // defaults for container-like structures that one might
40
+ // be interested in recursing into. You can enable them
41
+ // by adding:
42
+ //
43
+ // using IterArgs<YourStructName>::operator()
44
+ //
45
+ // to your struct. These are not enabled by default because
46
+ // you may be able to process these structures more efficiently
47
+ // than handling them one-by-one.
48
+
49
+ template <typename T>
50
+ void operator()(c10::IListRef<T> args) {
51
+ for (const auto& arg : args) {
52
+ self()(arg);
53
+ if (self().short_circuit())
54
+ return;
55
+ }
56
+ }
57
+
58
+ template <typename T>
59
+ void operator()(at::ArrayRef<T> args) {
60
+ for (const auto& arg : args) {
61
+ self()(arg);
62
+ if (self().short_circuit())
63
+ return;
64
+ }
65
+ }
66
+
67
+ template <typename T>
68
+ void operator()(const torch::List<T>& args) {
69
+ for (const auto& arg : args) {
70
+ self()(arg);
71
+ if (self().short_circuit())
72
+ return;
73
+ }
74
+ }
75
+
76
+ // NB: we need to specify std::vector manually as C++ won't
77
+ // do an implicit conversion to make a template deduction go through.
78
+ template <typename T>
79
+ void operator()(const std::vector<T>& args) {
80
+ self()(at::ArrayRef<T>{args});
81
+ }
82
+
83
+ constexpr bool short_circuit() const {
84
+ return false;
85
+ }
86
+
87
+ private:
88
+ inline F& self() {
89
+ return *static_cast<F*>(this);
90
+ }
91
+ };
92
+
93
+ } // namespace torch
94
+
95
+ #else
96
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
97
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Vitals.h ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ostream>
4
+ #include <sstream>
5
+ #include <unordered_map>
6
+
7
+ #include <c10/core/impl/LocalDispatchKeySet.h>
8
+
9
+ namespace at::vitals {
10
+
11
+ TORCH_API bool torchVitalEnabled();
12
+
13
+ struct TORCH_API TorchVitalAttr {
14
+ // always initialized to empty
15
+ std::string value;
16
+ template <typename T>
17
+ TorchVitalAttr& operator<<(const T& t) {
18
+ if (torchVitalEnabled()) {
19
+ std::stringstream ss;
20
+ ss << t;
21
+ value += ss.str();
22
+ }
23
+ return *this;
24
+ }
25
+
26
+ template <typename T>
27
+ void write(const T& t, bool force) {
28
+ if (force || torchVitalEnabled()) {
29
+ std::stringstream ss;
30
+ ss << t;
31
+ value = ss.str();
32
+ }
33
+ }
34
+ };
35
+
36
+ struct TORCH_API TorchVital {
37
+ std::string name;
38
+ std::unordered_map<std::string, TorchVitalAttr> attrs;
39
+
40
+ explicit TorchVital(std::string n) : name(std::move(n)) {}
41
+ TorchVital(const TorchVital&) = default;
42
+ TorchVital(TorchVital&&) = default;
43
+ TorchVital& operator=(const TorchVital&) = default;
44
+ TorchVital& operator=(TorchVital&&) = default;
45
+ TorchVital() = delete;
46
+
47
+ TorchVitalAttr& create(const std::string& attr);
48
+ TorchVitalAttr& create(const std::string& attr, bool force);
49
+ friend std::ostream& operator<<(std::ostream& os, const TorchVital& dt);
50
+
51
+ ~TorchVital();
52
+ };
53
+
54
+ std::ostream& operator<<(std::ostream& os, TorchVital const& tv);
55
+
56
+ // A way to access vitals by string names instead of by global reference.
57
+ // This enables access to vitals from the PythonAPI.
58
+ class TORCH_API APIVitals {
59
+ public:
60
+ bool vitals_enabled;
61
+
62
+ // Set any vital sign that was added to the map.
63
+ bool setVital(
64
+ const std::string& vital_name,
65
+ const std::string& attr_name,
66
+ const std::string& value,
67
+ bool force = false);
68
+ std::string readVitals();
69
+
70
+ APIVitals();
71
+
72
+ // Ensure this stays a singleton
73
+ APIVitals(APIVitals const& other) = delete;
74
+ APIVitals(APIVitals&& other) = delete;
75
+ APIVitals& operator=(const APIVitals&) = delete;
76
+ APIVitals& operator=(APIVitals&&) = delete;
77
+ ~APIVitals() = default;
78
+
79
+ private:
80
+ std::unordered_map<std::string, TorchVital> name_map_;
81
+ };
82
+
83
+ extern TORCH_API APIVitals VitalsAPI;
84
+
85
+ } // namespace at::vitals
86
+
87
+ #define TORCH_VITAL_DECLARE(name) \
88
+ TORCH_API at::vitals::TorchVital TorchVital_##name;
89
+
90
+ #define TORCH_VITAL_DEFINE(name) \
91
+ TORCH_API at::vitals::TorchVital TorchVital_##name(#name);
92
+
93
+ #define TORCH_VITAL_BASE(name) TorchVital_##name
94
+
95
+ #define TORCH_VITAL(name, attr) TORCH_VITAL_BASE(name).create(#attr)
96
+
97
+ #else
98
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
99
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/alias_info.h ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <set>
4
+ #include <string>
5
+ #include <unordered_set>
6
+ #include <vector>
7
+ #include <ATen/core/symbol.h>
8
+ #include <c10/util/Exception.h>
9
+ #include <c10/util/hash.h>
10
+
11
+ namespace c10 {
12
+ /**
13
+ * class AliasInfo
14
+ *
15
+ * Data structure to hold aliasing information for an `Argument`. They can be
16
+ * nested to represent aliasing information on contained types.
17
+ *
18
+ * There is a `beforeSet` which describes the aliasing information before the
19
+ * operator executes, and an `afterSet` that describes aliasing info
20
+ * after execution.
21
+ */
22
+ class AliasInfo {
23
+ public:
24
+ AliasInfo() = default;
25
+ AliasInfo(bool is_write, const std::set<std::string>& before_qual_strings, const std::set<std::string>& after_qual_strings) : isWrite_(is_write) {
26
+ for (const auto& s: before_qual_strings) {
27
+ beforeSets_.insert(Symbol::fromQualString(s));
28
+ }
29
+ for (const auto& s : after_qual_strings) {
30
+ afterSets_.insert(Symbol::fromQualString(s));
31
+ }
32
+ }
33
+ // Symbol for the set that can alias anything
34
+ static Symbol wildcardSet() {
35
+ static const Symbol wc = Symbol::fromQualString("alias::*");
36
+ return wc;
37
+ }
38
+
39
+ void setIsWrite(bool isWrite) {
40
+ isWrite_ = isWrite;
41
+ }
42
+
43
+ bool isWrite() const {
44
+ return isWrite_;
45
+ }
46
+
47
+ void addBeforeSet(Symbol aliasSet) {
48
+ beforeSets_.insert(aliasSet);
49
+ }
50
+
51
+ void addAfterSet(Symbol aliasSet) {
52
+ afterSets_.insert(aliasSet);
53
+ }
54
+
55
+ const std::unordered_set<Symbol>& beforeSets() const {
56
+ return beforeSets_;
57
+ }
58
+
59
+ const std::unordered_set<Symbol>& afterSets() const {
60
+ return afterSets_;
61
+ }
62
+
63
+ Symbol beforeSet() const {
64
+ AT_ASSERT(beforeSets_.size() == 1);
65
+ return *beforeSets_.begin();
66
+ }
67
+
68
+ bool isWildcardBefore() const {
69
+ return beforeSets_.count(wildcardSet()) != 0;
70
+ }
71
+
72
+ bool isWildcardAfter() const {
73
+ return afterSets_.count(wildcardSet()) != 0;
74
+ }
75
+
76
+ // the alias info for the contained types of the type
77
+ // e.g. if this is an annotation on List[T], `sets` refers to
78
+ // the alias sets that the list may be in
79
+ // while containedTypes()[0] refers to the sets that members of the list
80
+ // may be in
81
+ void addContainedType(AliasInfo aliasInfo) {
82
+ containedTypes_.push_back(std::move(aliasInfo));
83
+ }
84
+ const std::vector<AliasInfo>& containedTypes() const {
85
+ return containedTypes_;
86
+ }
87
+
88
+ private:
89
+ std::unordered_set<Symbol> beforeSets_;
90
+ std::unordered_set<Symbol> afterSets_;
91
+ std::vector<AliasInfo> containedTypes_;
92
+ bool isWrite_ = false;
93
+ };
94
+
95
+ inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
96
+ return lhs.isWrite() == rhs.isWrite()
97
+ && lhs.beforeSets() == rhs.beforeSets()
98
+ && lhs.afterSets() == rhs.afterSets()
99
+ && lhs.containedTypes() == rhs.containedTypes();
100
+ }
101
+
102
+ // this does match the way things are represented in the schema
103
+ inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
104
+ out << '(';
105
+ bool first = true;
106
+ for (const auto& set : aliasInfo.beforeSets()) {
107
+ if (first) {
108
+ first = false;
109
+ } else {
110
+ out << '|';
111
+ }
112
+ out << set.toUnqualString();
113
+ }
114
+ if (aliasInfo.isWrite()) {
115
+ out << '!';
116
+ }
117
+ if (aliasInfo.beforeSets() != aliasInfo.afterSets()) {
118
+ out << " -> ";
119
+ first = true;
120
+ for (const auto& set : aliasInfo.afterSets()) {
121
+ if (first) {
122
+ first = false;
123
+ } else {
124
+ out << '|';
125
+ }
126
+ out << set.toUnqualString();
127
+ }
128
+ }
129
+ out << ')';
130
+ return out;
131
+ }
132
+ } // namespace c10
133
+
134
+ namespace std {
135
+ template <>
136
+ struct hash<c10::AliasInfo> {
137
+ size_t operator()(const c10::AliasInfo& aliasInfo) const {
138
+ auto hash = std::hash<bool>()(aliasInfo.isWrite());
139
+
140
+ // NOTE: for unordered_set hashes, we couldn't use hash_combine
141
+ // because hash_combine is order dependent. Instead, we choose to
142
+ // use XOR as the combining function as XOR is commutative.
143
+ size_t before_set_hash_seed = 0;
144
+ for (auto &e: aliasInfo.beforeSets()) {
145
+ auto symbol_hash = std::hash<c10::Symbol>()(e);
146
+ before_set_hash_seed = before_set_hash_seed ^ symbol_hash;
147
+ }
148
+ size_t after_set_hash_seed = 0;
149
+ for (auto &e: aliasInfo.afterSets()) {
150
+ auto symbol_hash = std::hash<c10::Symbol>()(e);
151
+ after_set_hash_seed = after_set_hash_seed ^ symbol_hash;
152
+ }
153
+
154
+ hash = c10::hash_combine(hash, before_set_hash_seed);
155
+ hash = c10::hash_combine(hash, after_set_hash_seed);
156
+ for (auto &e: aliasInfo.containedTypes()) {
157
+ auto contained_type_hash = std::hash<c10::AliasInfo>()(e);
158
+ hash = c10::hash_combine(hash, contained_type_hash);
159
+ }
160
+ return hash;
161
+ }
162
+ };
163
+ }
164
+
165
+ #else
166
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
167
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/aten_interned_strings.h ADDED
@@ -0,0 +1,2309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // @generated by torchgen/gen.py from aten_interned_strings.h
5
+
6
+ #if defined(TORCH_ASSERT_NO_OPERATORS) || defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
7
+ #error This change adds a dependency on native_functions.yaml, \
8
+ meaning the file will need to be re-compiled every time an operator \
9
+ is changed or added. Consider if including <ATen/core/symbol.h> for \
10
+ the c10::Symbol class would be sufficient, or if your change would be \
11
+ better placed in another file.
12
+ #endif
13
+
14
+ // ATen symbols correspond exactly to operators defined in ATen. Every
15
+ // symbol here corresponds exactly to an ATen operation defined in
16
+ // native_functions.yaml; attributes are in one-to-one correspondence
17
+ // with their ATen name.
18
+
19
+ #define FORALL_ATEN_BASE_SYMBOLS(_) \
20
+ _(aten, __and__) \
21
+ _(aten, __iand__) \
22
+ _(aten, __ilshift__) \
23
+ _(aten, __ior__) \
24
+ _(aten, __irshift__) \
25
+ _(aten, __ixor__) \
26
+ _(aten, __lshift__) \
27
+ _(aten, __or__) \
28
+ _(aten, __rshift__) \
29
+ _(aten, __xor__) \
30
+ _(aten, _adaptive_avg_pool2d) \
31
+ _(aten, _adaptive_avg_pool2d_backward) \
32
+ _(aten, _adaptive_avg_pool3d) \
33
+ _(aten, _adaptive_avg_pool3d_backward) \
34
+ _(aten, _add_batch_dim) \
35
+ _(aten, _add_relu) \
36
+ _(aten, _add_relu_) \
37
+ _(aten, _addmm_activation) \
38
+ _(aten, _aminmax) \
39
+ _(aten, _amp_foreach_non_finite_check_and_unscale) \
40
+ _(aten, _amp_foreach_non_finite_check_and_unscale_) \
41
+ _(aten, _amp_update_scale) \
42
+ _(aten, _amp_update_scale_) \
43
+ _(aten, _assert_async) \
44
+ _(aten, _assert_scalar) \
45
+ _(aten, _assert_tensor_metadata) \
46
+ _(aten, _autocast_to_full_precision) \
47
+ _(aten, _autocast_to_reduced_precision) \
48
+ _(aten, _backward) \
49
+ _(aten, _batch_norm_impl_index) \
50
+ _(aten, _batch_norm_impl_index_backward) \
51
+ _(aten, _batch_norm_no_update) \
52
+ _(aten, _batch_norm_with_update) \
53
+ _(aten, _batch_norm_with_update_functional) \
54
+ _(aten, _cast_Byte) \
55
+ _(aten, _cast_Char) \
56
+ _(aten, _cast_Double) \
57
+ _(aten, _cast_Float) \
58
+ _(aten, _cast_Half) \
59
+ _(aten, _cast_Int) \
60
+ _(aten, _cast_Long) \
61
+ _(aten, _cast_Short) \
62
+ _(aten, _cdist_backward) \
63
+ _(aten, _cdist_forward) \
64
+ _(aten, _cholesky_solve_helper) \
65
+ _(aten, _choose_qparams_per_tensor) \
66
+ _(aten, _chunk_cat) \
67
+ _(aten, _coalesce) \
68
+ _(aten, _coalesced) \
69
+ _(aten, _coalesced_) \
70
+ _(aten, _compute_linear_combination) \
71
+ _(aten, _conj) \
72
+ _(aten, _conj_copy) \
73
+ _(aten, _conj_physical) \
74
+ _(aten, _conv_depthwise2d) \
75
+ _(aten, _convert_indices_from_coo_to_csr) \
76
+ _(aten, _convert_indices_from_csr_to_coo) \
77
+ _(aten, _convert_weight_to_int4pack) \
78
+ _(aten, _convert_weight_to_int4pack_for_cpu) \
79
+ _(aten, _convolution) \
80
+ _(aten, _convolution_double_backward) \
81
+ _(aten, _convolution_mode) \
82
+ _(aten, _copy_from) \
83
+ _(aten, _copy_from_and_resize) \
84
+ _(aten, _cslt_compress) \
85
+ _(aten, _cslt_sparse_mm) \
86
+ _(aten, _cslt_sparse_mm_search) \
87
+ _(aten, _ctc_loss) \
88
+ _(aten, _ctc_loss_backward) \
89
+ _(aten, _cudnn_attention_backward) \
90
+ _(aten, _cudnn_attention_forward) \
91
+ _(aten, _cudnn_ctc_loss) \
92
+ _(aten, _cudnn_init_dropout_state) \
93
+ _(aten, _cudnn_rnn) \
94
+ _(aten, _cudnn_rnn_backward) \
95
+ _(aten, _cudnn_rnn_flatten_weight) \
96
+ _(aten, _cufft_clear_plan_cache) \
97
+ _(aten, _cufft_get_plan_cache_max_size) \
98
+ _(aten, _cufft_get_plan_cache_size) \
99
+ _(aten, _cufft_set_plan_cache_max_size) \
100
+ _(aten, _cummax_helper) \
101
+ _(aten, _cummin_helper) \
102
+ _(aten, _debug_has_internal_overlap) \
103
+ _(aten, _dimI) \
104
+ _(aten, _dimV) \
105
+ _(aten, _dim_arange) \
106
+ _(aten, _dirichlet_grad) \
107
+ _(aten, _dyn_quant_matmul_4bit) \
108
+ _(aten, _dyn_quant_pack_4bit_weight) \
109
+ _(aten, _efficient_attention_backward) \
110
+ _(aten, _efficient_attention_forward) \
111
+ _(aten, _efficientzerotensor) \
112
+ _(aten, _embedding_bag) \
113
+ _(aten, _embedding_bag_backward) \
114
+ _(aten, _embedding_bag_dense_backward) \
115
+ _(aten, _embedding_bag_forward_only) \
116
+ _(aten, _embedding_bag_per_sample_weights_backward) \
117
+ _(aten, _embedding_bag_sparse_backward) \
118
+ _(aten, _empty_affine_quantized) \
119
+ _(aten, _empty_per_channel_affine_quantized) \
120
+ _(aten, _euclidean_dist) \
121
+ _(aten, _fake_quantize_learnable_per_channel_affine) \
122
+ _(aten, _fake_quantize_learnable_per_channel_affine_backward) \
123
+ _(aten, _fake_quantize_learnable_per_tensor_affine) \
124
+ _(aten, _fake_quantize_learnable_per_tensor_affine_backward) \
125
+ _(aten, _fake_quantize_per_tensor_affine_cachemask_tensor_qparams) \
126
+ _(aten, _fft_c2c) \
127
+ _(aten, _fft_c2r) \
128
+ _(aten, _fft_r2c) \
129
+ _(aten, _fill_mem_eff_dropout_mask) \
130
+ _(aten, _fill_mem_eff_dropout_mask_) \
131
+ _(aten, _flash_attention_backward) \
132
+ _(aten, _flash_attention_forward) \
133
+ _(aten, _foobar) \
134
+ _(aten, _foreach_abs) \
135
+ _(aten, _foreach_abs_) \
136
+ _(aten, _foreach_acos) \
137
+ _(aten, _foreach_acos_) \
138
+ _(aten, _foreach_add) \
139
+ _(aten, _foreach_add_) \
140
+ _(aten, _foreach_addcdiv) \
141
+ _(aten, _foreach_addcdiv_) \
142
+ _(aten, _foreach_addcmul) \
143
+ _(aten, _foreach_addcmul_) \
144
+ _(aten, _foreach_asin) \
145
+ _(aten, _foreach_asin_) \
146
+ _(aten, _foreach_atan) \
147
+ _(aten, _foreach_atan_) \
148
+ _(aten, _foreach_ceil) \
149
+ _(aten, _foreach_ceil_) \
150
+ _(aten, _foreach_clamp_max) \
151
+ _(aten, _foreach_clamp_max_) \
152
+ _(aten, _foreach_clamp_min) \
153
+ _(aten, _foreach_clamp_min_) \
154
+ _(aten, _foreach_copy) \
155
+ _(aten, _foreach_copy_) \
156
+ _(aten, _foreach_cos) \
157
+ _(aten, _foreach_cos_) \
158
+ _(aten, _foreach_cosh) \
159
+ _(aten, _foreach_cosh_) \
160
+ _(aten, _foreach_div) \
161
+ _(aten, _foreach_div_) \
162
+ _(aten, _foreach_erf) \
163
+ _(aten, _foreach_erf_) \
164
+ _(aten, _foreach_erfc) \
165
+ _(aten, _foreach_erfc_) \
166
+ _(aten, _foreach_exp) \
167
+ _(aten, _foreach_exp_) \
168
+ _(aten, _foreach_expm1) \
169
+ _(aten, _foreach_expm1_) \
170
+ _(aten, _foreach_floor) \
171
+ _(aten, _foreach_floor_) \
172
+ _(aten, _foreach_frac) \
173
+ _(aten, _foreach_frac_) \
174
+ _(aten, _foreach_lerp) \
175
+ _(aten, _foreach_lerp_) \
176
+ _(aten, _foreach_lgamma) \
177
+ _(aten, _foreach_lgamma_) \
178
+ _(aten, _foreach_log) \
179
+ _(aten, _foreach_log10) \
180
+ _(aten, _foreach_log10_) \
181
+ _(aten, _foreach_log1p) \
182
+ _(aten, _foreach_log1p_) \
183
+ _(aten, _foreach_log2) \
184
+ _(aten, _foreach_log2_) \
185
+ _(aten, _foreach_log_) \
186
+ _(aten, _foreach_max) \
187
+ _(aten, _foreach_maximum) \
188
+ _(aten, _foreach_maximum_) \
189
+ _(aten, _foreach_minimum) \
190
+ _(aten, _foreach_minimum_) \
191
+ _(aten, _foreach_mul) \
192
+ _(aten, _foreach_mul_) \
193
+ _(aten, _foreach_neg) \
194
+ _(aten, _foreach_neg_) \
195
+ _(aten, _foreach_norm) \
196
+ _(aten, _foreach_pow) \
197
+ _(aten, _foreach_pow_) \
198
+ _(aten, _foreach_reciprocal) \
199
+ _(aten, _foreach_reciprocal_) \
200
+ _(aten, _foreach_round) \
201
+ _(aten, _foreach_round_) \
202
+ _(aten, _foreach_rsqrt) \
203
+ _(aten, _foreach_rsqrt_) \
204
+ _(aten, _foreach_sigmoid) \
205
+ _(aten, _foreach_sigmoid_) \
206
+ _(aten, _foreach_sign) \
207
+ _(aten, _foreach_sign_) \
208
+ _(aten, _foreach_sin) \
209
+ _(aten, _foreach_sin_) \
210
+ _(aten, _foreach_sinh) \
211
+ _(aten, _foreach_sinh_) \
212
+ _(aten, _foreach_sqrt) \
213
+ _(aten, _foreach_sqrt_) \
214
+ _(aten, _foreach_sub) \
215
+ _(aten, _foreach_sub_) \
216
+ _(aten, _foreach_tan) \
217
+ _(aten, _foreach_tan_) \
218
+ _(aten, _foreach_tanh) \
219
+ _(aten, _foreach_tanh_) \
220
+ _(aten, _foreach_trunc) \
221
+ _(aten, _foreach_trunc_) \
222
+ _(aten, _foreach_zero) \
223
+ _(aten, _foreach_zero_) \
224
+ _(aten, _functional_assert_async) \
225
+ _(aten, _functional_assert_scalar) \
226
+ _(aten, _functional_sym_constrain_range) \
227
+ _(aten, _functional_sym_constrain_range_for_size) \
228
+ _(aten, _fused_adagrad) \
229
+ _(aten, _fused_adagrad_) \
230
+ _(aten, _fused_adam) \
231
+ _(aten, _fused_adam_) \
232
+ _(aten, _fused_adamw) \
233
+ _(aten, _fused_adamw_) \
234
+ _(aten, _fused_dropout) \
235
+ _(aten, _fused_moving_avg_obs_fq_helper) \
236
+ _(aten, _fused_moving_avg_obs_fq_helper_functional) \
237
+ _(aten, _fused_rms_norm) \
238
+ _(aten, _fused_rms_norm_backward) \
239
+ _(aten, _fused_sdp_choice) \
240
+ _(aten, _fused_sgd) \
241
+ _(aten, _fused_sgd_) \
242
+ _(aten, _fw_primal) \
243
+ _(aten, _fw_primal_copy) \
244
+ _(aten, _gather_sparse_backward) \
245
+ _(aten, _grid_sampler_2d_cpu_fallback) \
246
+ _(aten, _grid_sampler_2d_cpu_fallback_backward) \
247
+ _(aten, _grouped_mm) \
248
+ _(aten, _has_compatible_shallow_copy_type) \
249
+ _(aten, _has_same_storage_numel) \
250
+ _(aten, _histogramdd_bin_edges) \
251
+ _(aten, _histogramdd_from_bin_cts) \
252
+ _(aten, _histogramdd_from_bin_tensors) \
253
+ _(aten, _index_put_impl) \
254
+ _(aten, _index_put_impl_) \
255
+ _(aten, _indices) \
256
+ _(aten, _indices_copy) \
257
+ _(aten, _int_mm) \
258
+ _(aten, _is_all_true) \
259
+ _(aten, _is_any_true) \
260
+ _(aten, _is_zerotensor) \
261
+ _(aten, _jagged_to_padded_dense_forward) \
262
+ _(aten, _lazy_clone) \
263
+ _(aten, _linalg_check_errors) \
264
+ _(aten, _linalg_det) \
265
+ _(aten, _linalg_eigh) \
266
+ _(aten, _linalg_eigvals) \
267
+ _(aten, _linalg_slogdet) \
268
+ _(aten, _linalg_solve_ex) \
269
+ _(aten, _linalg_svd) \
270
+ _(aten, _local_scalar_dense) \
271
+ _(aten, _log_softmax) \
272
+ _(aten, _log_softmax_backward_data) \
273
+ _(aten, _logcumsumexp) \
274
+ _(aten, _lstm_mps) \
275
+ _(aten, _lu_with_info) \
276
+ _(aten, _make_dep_token) \
277
+ _(aten, _make_dual) \
278
+ _(aten, _make_dual_copy) \
279
+ _(aten, _make_per_channel_quantized_tensor) \
280
+ _(aten, _make_per_tensor_quantized_tensor) \
281
+ _(aten, _masked_scale) \
282
+ _(aten, _masked_softmax) \
283
+ _(aten, _masked_softmax_backward) \
284
+ _(aten, _mixed_dtypes_linear) \
285
+ _(aten, _mkldnn_reshape) \
286
+ _(aten, _mkldnn_transpose) \
287
+ _(aten, _mkldnn_transpose_) \
288
+ _(aten, _mps_convolution) \
289
+ _(aten, _mps_convolution_transpose) \
290
+ _(aten, _native_batch_norm_legit) \
291
+ _(aten, _native_batch_norm_legit_functional) \
292
+ _(aten, _native_batch_norm_legit_no_training) \
293
+ _(aten, _native_multi_head_attention) \
294
+ _(aten, _neg_view) \
295
+ _(aten, _neg_view_copy) \
296
+ _(aten, _nested_compute_contiguous_strides_offsets) \
297
+ _(aten, _nested_from_padded) \
298
+ _(aten, _nested_from_padded_and_nested_example) \
299
+ _(aten, _nested_from_padded_tensor) \
300
+ _(aten, _nested_get_jagged_dummy) \
301
+ _(aten, _nested_get_lengths) \
302
+ _(aten, _nested_get_max_seqlen) \
303
+ _(aten, _nested_get_min_seqlen) \
304
+ _(aten, _nested_get_offsets) \
305
+ _(aten, _nested_get_ragged_idx) \
306
+ _(aten, _nested_get_values) \
307
+ _(aten, _nested_get_values_copy) \
308
+ _(aten, _nested_select_backward) \
309
+ _(aten, _nested_sum_backward) \
310
+ _(aten, _nested_tensor_from_mask) \
311
+ _(aten, _nested_tensor_from_mask_left_aligned) \
312
+ _(aten, _nested_tensor_from_tensor_list) \
313
+ _(aten, _nested_tensor_size) \
314
+ _(aten, _nested_tensor_softmax_with_shape) \
315
+ _(aten, _nested_tensor_storage_offsets) \
316
+ _(aten, _nested_tensor_strides) \
317
+ _(aten, _nested_view_from_buffer) \
318
+ _(aten, _nested_view_from_buffer_copy) \
319
+ _(aten, _nested_view_from_jagged) \
320
+ _(aten, _nested_view_from_jagged_copy) \
321
+ _(aten, _new_zeros_with_same_feature_meta) \
322
+ _(aten, _nnpack_available) \
323
+ _(aten, _nnpack_spatial_convolution) \
324
+ _(aten, _nnz) \
325
+ _(aten, _pack_padded_sequence) \
326
+ _(aten, _pack_padded_sequence_backward) \
327
+ _(aten, _pad_circular) \
328
+ _(aten, _pad_enum) \
329
+ _(aten, _pad_packed_sequence) \
330
+ _(aten, _padded_dense_to_jagged_forward) \
331
+ _(aten, _pdist_backward) \
332
+ _(aten, _pdist_forward) \
333
+ _(aten, _pin_memory) \
334
+ _(aten, _prelu_kernel) \
335
+ _(aten, _prelu_kernel_backward) \
336
+ _(aten, _print) \
337
+ _(aten, _propagate_xla_data) \
338
+ _(aten, _remove_batch_dim) \
339
+ _(aten, _reshape_alias) \
340
+ _(aten, _reshape_alias_copy) \
341
+ _(aten, _reshape_copy) \
342
+ _(aten, _reshape_from_tensor) \
343
+ _(aten, _resize_output) \
344
+ _(aten, _resize_output_) \
345
+ _(aten, _rowwise_prune) \
346
+ _(aten, _safe_softmax) \
347
+ _(aten, _sample_dirichlet) \
348
+ _(aten, _saturate_weight_to_fp16) \
349
+ _(aten, _scaled_dot_product_attention_math) \
350
+ _(aten, _scaled_dot_product_attention_math_for_mps) \
351
+ _(aten, _scaled_dot_product_cudnn_attention) \
352
+ _(aten, _scaled_dot_product_cudnn_attention_backward) \
353
+ _(aten, _scaled_dot_product_efficient_attention) \
354
+ _(aten, _scaled_dot_product_efficient_attention_backward) \
355
+ _(aten, _scaled_dot_product_flash_attention) \
356
+ _(aten, _scaled_dot_product_flash_attention_backward) \
357
+ _(aten, _scaled_dot_product_flash_attention_for_cpu) \
358
+ _(aten, _scaled_dot_product_flash_attention_for_cpu_backward) \
359
+ _(aten, _scaled_dot_product_fused_attention_overrideable) \
360
+ _(aten, _scaled_dot_product_fused_attention_overrideable_backward) \
361
+ _(aten, _scaled_grouped_mm) \
362
+ _(aten, _scaled_grouped_mm_v2) \
363
+ _(aten, _scaled_mm) \
364
+ _(aten, _scaled_mm_v2) \
365
+ _(aten, _segment_reduce_backward) \
366
+ _(aten, _shape_as_tensor) \
367
+ _(aten, _slow_conv2d_backward) \
368
+ _(aten, _slow_conv2d_forward) \
369
+ _(aten, _sobol_engine_draw) \
370
+ _(aten, _sobol_engine_ff) \
371
+ _(aten, _sobol_engine_ff_) \
372
+ _(aten, _sobol_engine_initialize_state) \
373
+ _(aten, _sobol_engine_initialize_state_) \
374
+ _(aten, _sobol_engine_scramble) \
375
+ _(aten, _sobol_engine_scramble_) \
376
+ _(aten, _softmax) \
377
+ _(aten, _softmax_backward_data) \
378
+ _(aten, _sparse_addmm) \
379
+ _(aten, _sparse_broadcast_to) \
380
+ _(aten, _sparse_broadcast_to_copy) \
381
+ _(aten, _sparse_bsc_tensor_unsafe) \
382
+ _(aten, _sparse_bsr_tensor_unsafe) \
383
+ _(aten, _sparse_compressed_tensor_unsafe) \
384
+ _(aten, _sparse_compressed_tensor_with_dims) \
385
+ _(aten, _sparse_coo_tensor_unsafe) \
386
+ _(aten, _sparse_coo_tensor_with_dims) \
387
+ _(aten, _sparse_coo_tensor_with_dims_and_tensors) \
388
+ _(aten, _sparse_csc_tensor_unsafe) \
389
+ _(aten, _sparse_csr_prod) \
390
+ _(aten, _sparse_csr_sum) \
391
+ _(aten, _sparse_csr_tensor_unsafe) \
392
+ _(aten, _sparse_log_softmax) \
393
+ _(aten, _sparse_log_softmax_backward_data) \
394
+ _(aten, _sparse_mask_projection) \
395
+ _(aten, _sparse_mm) \
396
+ _(aten, _sparse_mm_reduce_impl) \
397
+ _(aten, _sparse_mm_reduce_impl_backward) \
398
+ _(aten, _sparse_semi_structured_addmm) \
399
+ _(aten, _sparse_semi_structured_apply) \
400
+ _(aten, _sparse_semi_structured_apply_dense) \
401
+ _(aten, _sparse_semi_structured_linear) \
402
+ _(aten, _sparse_semi_structured_mm) \
403
+ _(aten, _sparse_semi_structured_tile) \
404
+ _(aten, _sparse_softmax) \
405
+ _(aten, _sparse_softmax_backward_data) \
406
+ _(aten, _sparse_sparse_matmul) \
407
+ _(aten, _sparse_sum) \
408
+ _(aten, _sparse_sum_backward) \
409
+ _(aten, _spdiags) \
410
+ _(aten, _spsolve) \
411
+ _(aten, _stack) \
412
+ _(aten, _standard_gamma) \
413
+ _(aten, _standard_gamma_grad) \
414
+ _(aten, _test_ambiguous_defaults) \
415
+ _(aten, _test_autograd_multiple_dispatch) \
416
+ _(aten, _test_autograd_multiple_dispatch_view) \
417
+ _(aten, _test_autograd_multiple_dispatch_view_copy) \
418
+ _(aten, _test_check_tensor) \
419
+ _(aten, _test_functorch_fallback) \
420
+ _(aten, _test_optional_filled_intlist) \
421
+ _(aten, _test_optional_floatlist) \
422
+ _(aten, _test_optional_intlist) \
423
+ _(aten, _test_parallel_materialize) \
424
+ _(aten, _test_serialization_subcmul) \
425
+ _(aten, _test_string_default) \
426
+ _(aten, _test_warn_in_autograd) \
427
+ _(aten, _thnn_differentiable_gru_cell_backward) \
428
+ _(aten, _thnn_differentiable_lstm_cell_backward) \
429
+ _(aten, _thnn_fused_gru_cell) \
430
+ _(aten, _thnn_fused_gru_cell_backward) \
431
+ _(aten, _thnn_fused_lstm_cell) \
432
+ _(aten, _thnn_fused_lstm_cell_backward) \
433
+ _(aten, _thnn_fused_lstm_cell_backward_impl) \
434
+ _(aten, _to_copy) \
435
+ _(aten, _to_cpu) \
436
+ _(aten, _to_dense) \
437
+ _(aten, _to_sparse) \
438
+ _(aten, _to_sparse_bsc) \
439
+ _(aten, _to_sparse_bsr) \
440
+ _(aten, _to_sparse_csc) \
441
+ _(aten, _to_sparse_csr) \
442
+ _(aten, _to_sparse_semi_structured) \
443
+ _(aten, _transform_bias_rescale_qkv) \
444
+ _(aten, _transformer_encoder_layer_fwd) \
445
+ _(aten, _trilinear) \
446
+ _(aten, _triton_multi_head_attention) \
447
+ _(aten, _triton_scaled_dot_attention) \
448
+ _(aten, _unique) \
449
+ _(aten, _unique2) \
450
+ _(aten, _unpack_dual) \
451
+ _(aten, _unsafe_index) \
452
+ _(aten, _unsafe_index_put) \
453
+ _(aten, _unsafe_masked_index) \
454
+ _(aten, _unsafe_masked_index_put_accumulate) \
455
+ _(aten, _unsafe_view) \
456
+ _(aten, _upsample_bicubic2d_aa) \
457
+ _(aten, _upsample_bicubic2d_aa_backward) \
458
+ _(aten, _upsample_bilinear2d_aa) \
459
+ _(aten, _upsample_bilinear2d_aa_backward) \
460
+ _(aten, _upsample_nearest_exact1d) \
461
+ _(aten, _upsample_nearest_exact1d_backward) \
462
+ _(aten, _upsample_nearest_exact2d) \
463
+ _(aten, _upsample_nearest_exact2d_backward) \
464
+ _(aten, _upsample_nearest_exact3d) \
465
+ _(aten, _upsample_nearest_exact3d_backward) \
466
+ _(aten, _use_cudnn_ctc_loss) \
467
+ _(aten, _use_cudnn_rnn_flatten_weight) \
468
+ _(aten, _validate_compressed_sparse_indices) \
469
+ _(aten, _validate_sparse_bsc_tensor_args) \
470
+ _(aten, _validate_sparse_bsr_tensor_args) \
471
+ _(aten, _validate_sparse_compressed_tensor_args) \
472
+ _(aten, _validate_sparse_coo_tensor_args) \
473
+ _(aten, _validate_sparse_csc_tensor_args) \
474
+ _(aten, _validate_sparse_csr_tensor_args) \
475
+ _(aten, _values) \
476
+ _(aten, _values_copy) \
477
+ _(aten, _version) \
478
+ _(aten, _weight_int4pack_mm) \
479
+ _(aten, _weight_int4pack_mm_for_cpu) \
480
+ _(aten, _weight_int4pack_mm_with_scales_and_zeros) \
481
+ _(aten, _weight_int8pack_mm) \
482
+ _(aten, _weight_norm) \
483
+ _(aten, _weight_norm_differentiable_backward) \
484
+ _(aten, _weight_norm_interface) \
485
+ _(aten, _weight_norm_interface_backward) \
486
+ _(aten, _wrapped_linear_prepack) \
487
+ _(aten, _wrapped_quantized_linear_prepacked) \
488
+ _(aten, abs) \
489
+ _(aten, abs_) \
490
+ _(aten, absolute) \
491
+ _(aten, absolute_) \
492
+ _(aten, acos) \
493
+ _(aten, acos_) \
494
+ _(aten, acosh) \
495
+ _(aten, acosh_) \
496
+ _(aten, adaptive_avg_pool1d) \
497
+ _(aten, adaptive_avg_pool2d) \
498
+ _(aten, adaptive_avg_pool3d) \
499
+ _(aten, adaptive_avg_pool3d_backward) \
500
+ _(aten, adaptive_max_pool1d) \
501
+ _(aten, adaptive_max_pool2d) \
502
+ _(aten, adaptive_max_pool2d_backward) \
503
+ _(aten, adaptive_max_pool3d) \
504
+ _(aten, adaptive_max_pool3d_backward) \
505
+ _(aten, add) \
506
+ _(aten, add_) \
507
+ _(aten, addbmm) \
508
+ _(aten, addbmm_) \
509
+ _(aten, addcdiv) \
510
+ _(aten, addcdiv_) \
511
+ _(aten, addcmul) \
512
+ _(aten, addcmul_) \
513
+ _(aten, addmm) \
514
+ _(aten, addmm_) \
515
+ _(aten, addmv) \
516
+ _(aten, addmv_) \
517
+ _(aten, addr) \
518
+ _(aten, addr_) \
519
+ _(aten, adjoint) \
520
+ _(aten, affine_grid_generator) \
521
+ _(aten, affine_grid_generator_backward) \
522
+ _(aten, alias) \
523
+ _(aten, alias_copy) \
524
+ _(aten, align_as) \
525
+ _(aten, align_tensors) \
526
+ _(aten, align_to) \
527
+ _(aten, all) \
528
+ _(aten, allclose) \
529
+ _(aten, alpha_dropout) \
530
+ _(aten, alpha_dropout_) \
531
+ _(aten, amax) \
532
+ _(aten, amin) \
533
+ _(aten, aminmax) \
534
+ _(aten, angle) \
535
+ _(aten, any) \
536
+ _(aten, arange) \
537
+ _(aten, arccos) \
538
+ _(aten, arccos_) \
539
+ _(aten, arccosh) \
540
+ _(aten, arccosh_) \
541
+ _(aten, arcsin) \
542
+ _(aten, arcsin_) \
543
+ _(aten, arcsinh) \
544
+ _(aten, arcsinh_) \
545
+ _(aten, arctan) \
546
+ _(aten, arctan2) \
547
+ _(aten, arctan2_) \
548
+ _(aten, arctan_) \
549
+ _(aten, arctanh) \
550
+ _(aten, arctanh_) \
551
+ _(aten, argmax) \
552
+ _(aten, argmin) \
553
+ _(aten, argsort) \
554
+ _(aten, argwhere) \
555
+ _(aten, as_strided) \
556
+ _(aten, as_strided_) \
557
+ _(aten, as_strided_copy) \
558
+ _(aten, as_strided_scatter) \
559
+ _(aten, asin) \
560
+ _(aten, asin_) \
561
+ _(aten, asinh) \
562
+ _(aten, asinh_) \
563
+ _(aten, atan) \
564
+ _(aten, atan2) \
565
+ _(aten, atan2_) \
566
+ _(aten, atan_) \
567
+ _(aten, atanh) \
568
+ _(aten, atanh_) \
569
+ _(aten, atleast_1d) \
570
+ _(aten, atleast_2d) \
571
+ _(aten, atleast_3d) \
572
+ _(aten, avg_pool1d) \
573
+ _(aten, avg_pool2d) \
574
+ _(aten, avg_pool2d_backward) \
575
+ _(aten, avg_pool3d) \
576
+ _(aten, avg_pool3d_backward) \
577
+ _(aten, baddbmm) \
578
+ _(aten, baddbmm_) \
579
+ _(aten, bartlett_window) \
580
+ _(aten, batch_norm) \
581
+ _(aten, batch_norm_backward) \
582
+ _(aten, batch_norm_backward_elemt) \
583
+ _(aten, batch_norm_backward_reduce) \
584
+ _(aten, batch_norm_elemt) \
585
+ _(aten, batch_norm_gather_stats) \
586
+ _(aten, batch_norm_gather_stats_with_counts) \
587
+ _(aten, batch_norm_stats) \
588
+ _(aten, batch_norm_update_stats) \
589
+ _(aten, bernoulli) \
590
+ _(aten, bernoulli_) \
591
+ _(aten, bilinear) \
592
+ _(aten, binary_cross_entropy) \
593
+ _(aten, binary_cross_entropy_backward) \
594
+ _(aten, binary_cross_entropy_with_logits) \
595
+ _(aten, bincount) \
596
+ _(aten, binomial) \
597
+ _(aten, bitwise_and) \
598
+ _(aten, bitwise_and_) \
599
+ _(aten, bitwise_left_shift) \
600
+ _(aten, bitwise_left_shift_) \
601
+ _(aten, bitwise_not) \
602
+ _(aten, bitwise_not_) \
603
+ _(aten, bitwise_or) \
604
+ _(aten, bitwise_or_) \
605
+ _(aten, bitwise_right_shift) \
606
+ _(aten, bitwise_right_shift_) \
607
+ _(aten, bitwise_xor) \
608
+ _(aten, bitwise_xor_) \
609
+ _(aten, blackman_window) \
610
+ _(aten, block_diag) \
611
+ _(aten, bmm) \
612
+ _(aten, broadcast_tensors) \
613
+ _(aten, broadcast_to) \
614
+ _(aten, bucketize) \
615
+ _(aten, can_cast) \
616
+ _(aten, cartesian_prod) \
617
+ _(aten, cat) \
618
+ _(aten, cauchy) \
619
+ _(aten, cauchy_) \
620
+ _(aten, ccol_indices) \
621
+ _(aten, ccol_indices_copy) \
622
+ _(aten, cdist) \
623
+ _(aten, ceil) \
624
+ _(aten, ceil_) \
625
+ _(aten, celu) \
626
+ _(aten, celu_) \
627
+ _(aten, chain_matmul) \
628
+ _(aten, chalf) \
629
+ _(aten, channel_shuffle) \
630
+ _(aten, cholesky) \
631
+ _(aten, cholesky_inverse) \
632
+ _(aten, cholesky_solve) \
633
+ _(aten, choose_qparams_optimized) \
634
+ _(aten, chunk) \
635
+ _(aten, clamp) \
636
+ _(aten, clamp_) \
637
+ _(aten, clamp_max) \
638
+ _(aten, clamp_max_) \
639
+ _(aten, clamp_min) \
640
+ _(aten, clamp_min_) \
641
+ _(aten, clip) \
642
+ _(aten, clip_) \
643
+ _(aten, clone) \
644
+ _(aten, coalesce) \
645
+ _(aten, col2im) \
646
+ _(aten, col_indices) \
647
+ _(aten, col_indices_copy) \
648
+ _(aten, column_stack) \
649
+ _(aten, combinations) \
650
+ _(aten, complex) \
651
+ _(aten, concat) \
652
+ _(aten, concatenate) \
653
+ _(aten, conj) \
654
+ _(aten, conj_physical) \
655
+ _(aten, conj_physical_) \
656
+ _(aten, constant_pad_nd) \
657
+ _(aten, contiguous) \
658
+ _(aten, conv1d) \
659
+ _(aten, conv2d) \
660
+ _(aten, conv3d) \
661
+ _(aten, conv_depthwise3d) \
662
+ _(aten, conv_tbc) \
663
+ _(aten, conv_tbc_backward) \
664
+ _(aten, conv_transpose1d) \
665
+ _(aten, conv_transpose2d) \
666
+ _(aten, conv_transpose3d) \
667
+ _(aten, convolution) \
668
+ _(aten, convolution_backward) \
669
+ _(aten, convolution_backward_overrideable) \
670
+ _(aten, convolution_overrideable) \
671
+ _(aten, copy) \
672
+ _(aten, copy_) \
673
+ _(aten, copy_sparse_to_sparse) \
674
+ _(aten, copy_sparse_to_sparse_) \
675
+ _(aten, copysign) \
676
+ _(aten, copysign_) \
677
+ _(aten, corrcoef) \
678
+ _(aten, cos) \
679
+ _(aten, cos_) \
680
+ _(aten, cosh) \
681
+ _(aten, cosh_) \
682
+ _(aten, cosine_embedding_loss) \
683
+ _(aten, cosine_similarity) \
684
+ _(aten, count_nonzero) \
685
+ _(aten, cov) \
686
+ _(aten, cross) \
687
+ _(aten, cross_entropy_loss) \
688
+ _(aten, crow_indices) \
689
+ _(aten, crow_indices_copy) \
690
+ _(aten, ctc_loss) \
691
+ _(aten, cudnn_affine_grid_generator) \
692
+ _(aten, cudnn_affine_grid_generator_backward) \
693
+ _(aten, cudnn_batch_norm) \
694
+ _(aten, cudnn_batch_norm_backward) \
695
+ _(aten, cudnn_convolution) \
696
+ _(aten, cudnn_convolution_add_relu) \
697
+ _(aten, cudnn_convolution_relu) \
698
+ _(aten, cudnn_convolution_transpose) \
699
+ _(aten, cudnn_grid_sampler) \
700
+ _(aten, cudnn_grid_sampler_backward) \
701
+ _(aten, cudnn_is_acceptable) \
702
+ _(aten, cummax) \
703
+ _(aten, cummaxmin_backward) \
704
+ _(aten, cummin) \
705
+ _(aten, cumprod) \
706
+ _(aten, cumprod_) \
707
+ _(aten, cumprod_backward) \
708
+ _(aten, cumsum) \
709
+ _(aten, cumsum_) \
710
+ _(aten, cumulative_trapezoid) \
711
+ _(aten, data) \
712
+ _(aten, deg2rad) \
713
+ _(aten, deg2rad_) \
714
+ _(aten, dense_dim) \
715
+ _(aten, dequantize) \
716
+ _(aten, det) \
717
+ _(aten, detach) \
718
+ _(aten, detach_) \
719
+ _(aten, detach_copy) \
720
+ _(aten, diag) \
721
+ _(aten, diag_embed) \
722
+ _(aten, diagflat) \
723
+ _(aten, diagonal) \
724
+ _(aten, diagonal_backward) \
725
+ _(aten, diagonal_copy) \
726
+ _(aten, diagonal_scatter) \
727
+ _(aten, diff) \
728
+ _(aten, digamma) \
729
+ _(aten, digamma_) \
730
+ _(aten, dist) \
731
+ _(aten, div) \
732
+ _(aten, div_) \
733
+ _(aten, divide) \
734
+ _(aten, divide_) \
735
+ _(aten, dot) \
736
+ _(aten, dropout) \
737
+ _(aten, dropout_) \
738
+ _(aten, dsplit) \
739
+ _(aten, dstack) \
740
+ _(aten, einsum) \
741
+ _(aten, elu) \
742
+ _(aten, elu_) \
743
+ _(aten, elu_backward) \
744
+ _(aten, embedding) \
745
+ _(aten, embedding_backward) \
746
+ _(aten, embedding_bag) \
747
+ _(aten, embedding_dense_backward) \
748
+ _(aten, embedding_renorm) \
749
+ _(aten, embedding_renorm_) \
750
+ _(aten, embedding_sparse_backward) \
751
+ _(aten, empty) \
752
+ _(aten, empty_like) \
753
+ _(aten, empty_permuted) \
754
+ _(aten, empty_quantized) \
755
+ _(aten, empty_strided) \
756
+ _(aten, eq) \
757
+ _(aten, eq_) \
758
+ _(aten, equal) \
759
+ _(aten, erf) \
760
+ _(aten, erf_) \
761
+ _(aten, erfc) \
762
+ _(aten, erfc_) \
763
+ _(aten, erfinv) \
764
+ _(aten, erfinv_) \
765
+ _(aten, exp) \
766
+ _(aten, exp2) \
767
+ _(aten, exp2_) \
768
+ _(aten, exp_) \
769
+ _(aten, expand) \
770
+ _(aten, expand_as) \
771
+ _(aten, expand_copy) \
772
+ _(aten, expm1) \
773
+ _(aten, expm1_) \
774
+ _(aten, exponential) \
775
+ _(aten, exponential_) \
776
+ _(aten, eye) \
777
+ _(aten, fake_quantize_per_channel_affine) \
778
+ _(aten, fake_quantize_per_channel_affine_cachemask) \
779
+ _(aten, fake_quantize_per_channel_affine_cachemask_backward) \
780
+ _(aten, fake_quantize_per_tensor_affine) \
781
+ _(aten, fake_quantize_per_tensor_affine_cachemask) \
782
+ _(aten, fake_quantize_per_tensor_affine_cachemask_backward) \
783
+ _(aten, fbgemm_linear_fp16_weight) \
784
+ _(aten, fbgemm_linear_fp16_weight_fp32_activation) \
785
+ _(aten, fbgemm_linear_int8_weight) \
786
+ _(aten, fbgemm_linear_int8_weight_fp32_activation) \
787
+ _(aten, fbgemm_linear_quantize_weight) \
788
+ _(aten, fbgemm_pack_gemm_matrix_fp16) \
789
+ _(aten, fbgemm_pack_quantized_matrix) \
790
+ _(aten, feature_alpha_dropout) \
791
+ _(aten, feature_alpha_dropout_) \
792
+ _(aten, feature_dropout) \
793
+ _(aten, feature_dropout_) \
794
+ _(aten, fft_fft) \
795
+ _(aten, fft_fft2) \
796
+ _(aten, fft_fftfreq) \
797
+ _(aten, fft_fftn) \
798
+ _(aten, fft_fftshift) \
799
+ _(aten, fft_hfft) \
800
+ _(aten, fft_hfft2) \
801
+ _(aten, fft_hfftn) \
802
+ _(aten, fft_ifft) \
803
+ _(aten, fft_ifft2) \
804
+ _(aten, fft_ifftn) \
805
+ _(aten, fft_ifftshift) \
806
+ _(aten, fft_ihfft) \
807
+ _(aten, fft_ihfft2) \
808
+ _(aten, fft_ihfftn) \
809
+ _(aten, fft_irfft) \
810
+ _(aten, fft_irfft2) \
811
+ _(aten, fft_irfftn) \
812
+ _(aten, fft_rfft) \
813
+ _(aten, fft_rfft2) \
814
+ _(aten, fft_rfftfreq) \
815
+ _(aten, fft_rfftn) \
816
+ _(aten, fill) \
817
+ _(aten, fill_) \
818
+ _(aten, fill_diagonal) \
819
+ _(aten, fill_diagonal_) \
820
+ _(aten, fix) \
821
+ _(aten, fix_) \
822
+ _(aten, flatten) \
823
+ _(aten, flatten_dense_tensors) \
824
+ _(aten, flip) \
825
+ _(aten, fliplr) \
826
+ _(aten, flipud) \
827
+ _(aten, float_power) \
828
+ _(aten, float_power_) \
829
+ _(aten, floor) \
830
+ _(aten, floor_) \
831
+ _(aten, floor_divide) \
832
+ _(aten, floor_divide_) \
833
+ _(aten, fmax) \
834
+ _(aten, fmin) \
835
+ _(aten, fmod) \
836
+ _(aten, fmod_) \
837
+ _(aten, frac) \
838
+ _(aten, frac_) \
839
+ _(aten, fractional_max_pool2d) \
840
+ _(aten, fractional_max_pool2d_backward) \
841
+ _(aten, fractional_max_pool3d) \
842
+ _(aten, fractional_max_pool3d_backward) \
843
+ _(aten, frexp) \
844
+ _(aten, frobenius_norm) \
845
+ _(aten, from_file) \
846
+ _(aten, full) \
847
+ _(aten, full_like) \
848
+ _(aten, fused_moving_avg_obs_fake_quant) \
849
+ _(aten, gather) \
850
+ _(aten, gather_backward) \
851
+ _(aten, gcd) \
852
+ _(aten, gcd_) \
853
+ _(aten, ge) \
854
+ _(aten, ge_) \
855
+ _(aten, gelu) \
856
+ _(aten, gelu_) \
857
+ _(aten, gelu_backward) \
858
+ _(aten, geometric) \
859
+ _(aten, geometric_) \
860
+ _(aten, geqrf) \
861
+ _(aten, ger) \
862
+ _(aten, glu) \
863
+ _(aten, glu_backward) \
864
+ _(aten, glu_backward_jvp) \
865
+ _(aten, glu_jvp) \
866
+ _(aten, gradient) \
867
+ _(aten, greater) \
868
+ _(aten, greater_) \
869
+ _(aten, greater_equal) \
870
+ _(aten, greater_equal_) \
871
+ _(aten, grid_sampler) \
872
+ _(aten, grid_sampler_2d) \
873
+ _(aten, grid_sampler_2d_backward) \
874
+ _(aten, grid_sampler_3d) \
875
+ _(aten, grid_sampler_3d_backward) \
876
+ _(aten, group_norm) \
877
+ _(aten, gru) \
878
+ _(aten, gru_cell) \
879
+ _(aten, gt) \
880
+ _(aten, gt_) \
881
+ _(aten, hamming_window) \
882
+ _(aten, hann_window) \
883
+ _(aten, hardshrink) \
884
+ _(aten, hardshrink_backward) \
885
+ _(aten, hardsigmoid) \
886
+ _(aten, hardsigmoid_) \
887
+ _(aten, hardsigmoid_backward) \
888
+ _(aten, hardswish) \
889
+ _(aten, hardswish_) \
890
+ _(aten, hardswish_backward) \
891
+ _(aten, hardtanh) \
892
+ _(aten, hardtanh_) \
893
+ _(aten, hardtanh_backward) \
894
+ _(aten, hash_tensor) \
895
+ _(aten, heaviside) \
896
+ _(aten, heaviside_) \
897
+ _(aten, hinge_embedding_loss) \
898
+ _(aten, histc) \
899
+ _(aten, histogram) \
900
+ _(aten, histogramdd) \
901
+ _(aten, hsplit) \
902
+ _(aten, hspmm) \
903
+ _(aten, hstack) \
904
+ _(aten, huber_loss) \
905
+ _(aten, huber_loss_backward) \
906
+ _(aten, hypot) \
907
+ _(aten, hypot_) \
908
+ _(aten, i0) \
909
+ _(aten, i0_) \
910
+ _(aten, igamma) \
911
+ _(aten, igamma_) \
912
+ _(aten, igammac) \
913
+ _(aten, igammac_) \
914
+ _(aten, im2col) \
915
+ _(aten, imag) \
916
+ _(aten, index) \
917
+ _(aten, index_add) \
918
+ _(aten, index_add_) \
919
+ _(aten, index_copy) \
920
+ _(aten, index_copy_) \
921
+ _(aten, index_fill) \
922
+ _(aten, index_fill_) \
923
+ _(aten, index_put) \
924
+ _(aten, index_put_) \
925
+ _(aten, index_reduce) \
926
+ _(aten, index_reduce_) \
927
+ _(aten, index_select) \
928
+ _(aten, index_select_backward) \
929
+ _(aten, indices) \
930
+ _(aten, indices_copy) \
931
+ _(aten, infinitely_differentiable_gelu_backward) \
932
+ _(aten, inner) \
933
+ _(aten, instance_norm) \
934
+ _(aten, int_repr) \
935
+ _(aten, inverse) \
936
+ _(aten, is_coalesced) \
937
+ _(aten, is_complex) \
938
+ _(aten, is_conj) \
939
+ _(aten, is_distributed) \
940
+ _(aten, is_floating_point) \
941
+ _(aten, is_inference) \
942
+ _(aten, is_leaf) \
943
+ _(aten, is_neg) \
944
+ _(aten, is_nonzero) \
945
+ _(aten, is_pinned) \
946
+ _(aten, is_same_size) \
947
+ _(aten, is_set_to) \
948
+ _(aten, is_signed) \
949
+ _(aten, is_vulkan_available) \
950
+ _(aten, isclose) \
951
+ _(aten, isfinite) \
952
+ _(aten, isin) \
953
+ _(aten, isinf) \
954
+ _(aten, isnan) \
955
+ _(aten, isneginf) \
956
+ _(aten, isposinf) \
957
+ _(aten, isreal) \
958
+ _(aten, istft) \
959
+ _(aten, item) \
960
+ _(aten, kaiser_window) \
961
+ _(aten, kl_div) \
962
+ _(aten, kron) \
963
+ _(aten, kthvalue) \
964
+ _(aten, l1_loss) \
965
+ _(aten, layer_norm) \
966
+ _(aten, lcm) \
967
+ _(aten, lcm_) \
968
+ _(aten, ldexp) \
969
+ _(aten, ldexp_) \
970
+ _(aten, le) \
971
+ _(aten, le_) \
972
+ _(aten, leaky_relu) \
973
+ _(aten, leaky_relu_) \
974
+ _(aten, leaky_relu_backward) \
975
+ _(aten, lerp) \
976
+ _(aten, lerp_) \
977
+ _(aten, less) \
978
+ _(aten, less_) \
979
+ _(aten, less_equal) \
980
+ _(aten, less_equal_) \
981
+ _(aten, lgamma) \
982
+ _(aten, lgamma_) \
983
+ _(aten, lift) \
984
+ _(aten, lift_fresh) \
985
+ _(aten, lift_fresh_copy) \
986
+ _(aten, linalg_cholesky) \
987
+ _(aten, linalg_cholesky_ex) \
988
+ _(aten, linalg_cond) \
989
+ _(aten, linalg_cross) \
990
+ _(aten, linalg_det) \
991
+ _(aten, linalg_diagonal) \
992
+ _(aten, linalg_eig) \
993
+ _(aten, linalg_eigh) \
994
+ _(aten, linalg_eigvals) \
995
+ _(aten, linalg_eigvalsh) \
996
+ _(aten, linalg_householder_product) \
997
+ _(aten, linalg_inv) \
998
+ _(aten, linalg_inv_ex) \
999
+ _(aten, linalg_ldl_factor) \
1000
+ _(aten, linalg_ldl_factor_ex) \
1001
+ _(aten, linalg_ldl_solve) \
1002
+ _(aten, linalg_lstsq) \
1003
+ _(aten, linalg_lu) \
1004
+ _(aten, linalg_lu_factor) \
1005
+ _(aten, linalg_lu_factor_ex) \
1006
+ _(aten, linalg_lu_solve) \
1007
+ _(aten, linalg_matmul) \
1008
+ _(aten, linalg_matrix_exp) \
1009
+ _(aten, linalg_matrix_norm) \
1010
+ _(aten, linalg_matrix_power) \
1011
+ _(aten, linalg_matrix_rank) \
1012
+ _(aten, linalg_multi_dot) \
1013
+ _(aten, linalg_norm) \
1014
+ _(aten, linalg_pinv) \
1015
+ _(aten, linalg_qr) \
1016
+ _(aten, linalg_slogdet) \
1017
+ _(aten, linalg_solve) \
1018
+ _(aten, linalg_solve_ex) \
1019
+ _(aten, linalg_solve_triangular) \
1020
+ _(aten, linalg_svd) \
1021
+ _(aten, linalg_svdvals) \
1022
+ _(aten, linalg_tensorinv) \
1023
+ _(aten, linalg_tensorsolve) \
1024
+ _(aten, linalg_vander) \
1025
+ _(aten, linalg_vecdot) \
1026
+ _(aten, linalg_vector_norm) \
1027
+ _(aten, linear) \
1028
+ _(aten, linear_backward) \
1029
+ _(aten, linspace) \
1030
+ _(aten, log) \
1031
+ _(aten, log10) \
1032
+ _(aten, log10_) \
1033
+ _(aten, log1p) \
1034
+ _(aten, log1p_) \
1035
+ _(aten, log2) \
1036
+ _(aten, log2_) \
1037
+ _(aten, log_) \
1038
+ _(aten, log_normal) \
1039
+ _(aten, log_normal_) \
1040
+ _(aten, log_sigmoid) \
1041
+ _(aten, log_sigmoid_backward) \
1042
+ _(aten, log_sigmoid_forward) \
1043
+ _(aten, log_softmax) \
1044
+ _(aten, logaddexp) \
1045
+ _(aten, logaddexp2) \
1046
+ _(aten, logcumsumexp) \
1047
+ _(aten, logdet) \
1048
+ _(aten, logical_and) \
1049
+ _(aten, logical_and_) \
1050
+ _(aten, logical_not) \
1051
+ _(aten, logical_not_) \
1052
+ _(aten, logical_or) \
1053
+ _(aten, logical_or_) \
1054
+ _(aten, logical_xor) \
1055
+ _(aten, logical_xor_) \
1056
+ _(aten, logit) \
1057
+ _(aten, logit_) \
1058
+ _(aten, logit_backward) \
1059
+ _(aten, logspace) \
1060
+ _(aten, logsumexp) \
1061
+ _(aten, lshift) \
1062
+ _(aten, lstm) \
1063
+ _(aten, lstm_cell) \
1064
+ _(aten, lstm_mps_backward) \
1065
+ _(aten, lt) \
1066
+ _(aten, lt_) \
1067
+ _(aten, lu_solve) \
1068
+ _(aten, lu_unpack) \
1069
+ _(aten, mH) \
1070
+ _(aten, mT) \
1071
+ _(aten, margin_ranking_loss) \
1072
+ _(aten, masked_fill) \
1073
+ _(aten, masked_fill_) \
1074
+ _(aten, masked_scatter) \
1075
+ _(aten, masked_scatter_) \
1076
+ _(aten, masked_scatter_backward) \
1077
+ _(aten, masked_select) \
1078
+ _(aten, masked_select_backward) \
1079
+ _(aten, matmul) \
1080
+ _(aten, matmul_backward) \
1081
+ _(aten, matrix_H) \
1082
+ _(aten, matrix_exp) \
1083
+ _(aten, matrix_exp_backward) \
1084
+ _(aten, matrix_power) \
1085
+ _(aten, max) \
1086
+ _(aten, max_pool1d) \
1087
+ _(aten, max_pool1d_with_indices) \
1088
+ _(aten, max_pool2d) \
1089
+ _(aten, max_pool2d_backward) \
1090
+ _(aten, max_pool2d_with_indices) \
1091
+ _(aten, max_pool2d_with_indices_backward) \
1092
+ _(aten, max_pool3d) \
1093
+ _(aten, max_pool3d_with_indices) \
1094
+ _(aten, max_pool3d_with_indices_backward) \
1095
+ _(aten, max_unpool2d) \
1096
+ _(aten, max_unpool3d) \
1097
+ _(aten, maximum) \
1098
+ _(aten, mean) \
1099
+ _(aten, median) \
1100
+ _(aten, meshgrid) \
1101
+ _(aten, min) \
1102
+ _(aten, minimum) \
1103
+ _(aten, miopen_batch_norm) \
1104
+ _(aten, miopen_batch_norm_backward) \
1105
+ _(aten, miopen_convolution) \
1106
+ _(aten, miopen_convolution_add_relu) \
1107
+ _(aten, miopen_convolution_relu) \
1108
+ _(aten, miopen_convolution_transpose) \
1109
+ _(aten, miopen_depthwise_convolution) \
1110
+ _(aten, miopen_rnn) \
1111
+ _(aten, miopen_rnn_backward) \
1112
+ _(aten, mish) \
1113
+ _(aten, mish_) \
1114
+ _(aten, mish_backward) \
1115
+ _(aten, mkldnn_adaptive_avg_pool2d) \
1116
+ _(aten, mkldnn_adaptive_avg_pool2d_backward) \
1117
+ _(aten, mkldnn_convolution) \
1118
+ _(aten, mkldnn_linear) \
1119
+ _(aten, mkldnn_linear_backward) \
1120
+ _(aten, mkldnn_linear_backward_input) \
1121
+ _(aten, mkldnn_linear_backward_weights) \
1122
+ _(aten, mkldnn_max_pool2d) \
1123
+ _(aten, mkldnn_max_pool2d_backward) \
1124
+ _(aten, mkldnn_max_pool3d) \
1125
+ _(aten, mkldnn_max_pool3d_backward) \
1126
+ _(aten, mkldnn_reorder_conv2d_weight) \
1127
+ _(aten, mkldnn_reorder_conv3d_weight) \
1128
+ _(aten, mkldnn_rnn_layer) \
1129
+ _(aten, mkldnn_rnn_layer_backward) \
1130
+ _(aten, mm) \
1131
+ _(aten, mode) \
1132
+ _(aten, moveaxis) \
1133
+ _(aten, movedim) \
1134
+ _(aten, mps_convolution_backward) \
1135
+ _(aten, mps_convolution_transpose_backward) \
1136
+ _(aten, mse_loss) \
1137
+ _(aten, mse_loss_backward) \
1138
+ _(aten, msort) \
1139
+ _(aten, mul) \
1140
+ _(aten, mul_) \
1141
+ _(aten, multi_margin_loss) \
1142
+ _(aten, multi_margin_loss_backward) \
1143
+ _(aten, multilabel_margin_loss) \
1144
+ _(aten, multilabel_margin_loss_backward) \
1145
+ _(aten, multilabel_margin_loss_forward) \
1146
+ _(aten, multinomial) \
1147
+ _(aten, multiply) \
1148
+ _(aten, multiply_) \
1149
+ _(aten, mv) \
1150
+ _(aten, mvlgamma) \
1151
+ _(aten, mvlgamma_) \
1152
+ _(aten, nan_to_num) \
1153
+ _(aten, nan_to_num_) \
1154
+ _(aten, nanmean) \
1155
+ _(aten, nanmedian) \
1156
+ _(aten, nanquantile) \
1157
+ _(aten, nansum) \
1158
+ _(aten, narrow) \
1159
+ _(aten, narrow_copy) \
1160
+ _(aten, native_batch_norm) \
1161
+ _(aten, native_batch_norm_backward) \
1162
+ _(aten, native_channel_shuffle) \
1163
+ _(aten, native_dropout) \
1164
+ _(aten, native_dropout_backward) \
1165
+ _(aten, native_group_norm) \
1166
+ _(aten, native_group_norm_backward) \
1167
+ _(aten, native_layer_norm) \
1168
+ _(aten, native_layer_norm_backward) \
1169
+ _(aten, native_norm) \
1170
+ _(aten, ne) \
1171
+ _(aten, ne_) \
1172
+ _(aten, neg) \
1173
+ _(aten, neg_) \
1174
+ _(aten, negative) \
1175
+ _(aten, negative_) \
1176
+ _(aten, nested_to_padded_tensor) \
1177
+ _(aten, new_empty) \
1178
+ _(aten, new_empty_strided) \
1179
+ _(aten, new_full) \
1180
+ _(aten, new_ones) \
1181
+ _(aten, new_zeros) \
1182
+ _(aten, nextafter) \
1183
+ _(aten, nextafter_) \
1184
+ _(aten, nll_loss) \
1185
+ _(aten, nll_loss2d) \
1186
+ _(aten, nll_loss2d_backward) \
1187
+ _(aten, nll_loss2d_forward) \
1188
+ _(aten, nll_loss_backward) \
1189
+ _(aten, nll_loss_forward) \
1190
+ _(aten, nll_loss_nd) \
1191
+ _(aten, nonzero) \
1192
+ _(aten, nonzero_numpy) \
1193
+ _(aten, nonzero_static) \
1194
+ _(aten, norm) \
1195
+ _(aten, norm_except_dim) \
1196
+ _(aten, normal) \
1197
+ _(aten, normal_) \
1198
+ _(aten, normal_functional) \
1199
+ _(aten, not_equal) \
1200
+ _(aten, not_equal_) \
1201
+ _(aten, nuclear_norm) \
1202
+ _(aten, numpy_T) \
1203
+ _(aten, one_hot) \
1204
+ _(aten, ones) \
1205
+ _(aten, ones_like) \
1206
+ _(aten, orgqr) \
1207
+ _(aten, ormqr) \
1208
+ _(aten, outer) \
1209
+ _(aten, output_nr) \
1210
+ _(aten, pad) \
1211
+ _(aten, pad_sequence) \
1212
+ _(aten, pairwise_distance) \
1213
+ _(aten, pdist) \
1214
+ _(aten, permute) \
1215
+ _(aten, permute_copy) \
1216
+ _(aten, pin_memory) \
1217
+ _(aten, pinverse) \
1218
+ _(aten, pixel_shuffle) \
1219
+ _(aten, pixel_unshuffle) \
1220
+ _(aten, poisson) \
1221
+ _(aten, poisson_nll_loss) \
1222
+ _(aten, polar) \
1223
+ _(aten, polygamma) \
1224
+ _(aten, polygamma_) \
1225
+ _(aten, positive) \
1226
+ _(aten, pow) \
1227
+ _(aten, pow_) \
1228
+ _(aten, prelu) \
1229
+ _(aten, prod) \
1230
+ _(aten, promote_types) \
1231
+ _(aten, put) \
1232
+ _(aten, put_) \
1233
+ _(aten, q_per_channel_axis) \
1234
+ _(aten, q_per_channel_scales) \
1235
+ _(aten, q_per_channel_zero_points) \
1236
+ _(aten, q_scale) \
1237
+ _(aten, q_zero_point) \
1238
+ _(aten, qr) \
1239
+ _(aten, qscheme) \
1240
+ _(aten, quantile) \
1241
+ _(aten, quantize_per_channel) \
1242
+ _(aten, quantize_per_tensor) \
1243
+ _(aten, quantize_per_tensor_dynamic) \
1244
+ _(aten, quantized_batch_norm) \
1245
+ _(aten, quantized_gru_cell) \
1246
+ _(aten, quantized_lstm_cell) \
1247
+ _(aten, quantized_max_pool1d) \
1248
+ _(aten, quantized_max_pool2d) \
1249
+ _(aten, quantized_max_pool3d) \
1250
+ _(aten, quantized_rnn_relu_cell) \
1251
+ _(aten, quantized_rnn_tanh_cell) \
1252
+ _(aten, rad2deg) \
1253
+ _(aten, rad2deg_) \
1254
+ _(aten, rand) \
1255
+ _(aten, rand_like) \
1256
+ _(aten, randint) \
1257
+ _(aten, randint_like) \
1258
+ _(aten, randn) \
1259
+ _(aten, randn_like) \
1260
+ _(aten, random) \
1261
+ _(aten, random_) \
1262
+ _(aten, randperm) \
1263
+ _(aten, range) \
1264
+ _(aten, ravel) \
1265
+ _(aten, real) \
1266
+ _(aten, reciprocal) \
1267
+ _(aten, reciprocal_) \
1268
+ _(aten, record_stream) \
1269
+ _(aten, refine_names) \
1270
+ _(aten, reflection_pad1d) \
1271
+ _(aten, reflection_pad1d_backward) \
1272
+ _(aten, reflection_pad2d) \
1273
+ _(aten, reflection_pad2d_backward) \
1274
+ _(aten, reflection_pad3d) \
1275
+ _(aten, reflection_pad3d_backward) \
1276
+ _(aten, relu) \
1277
+ _(aten, relu6) \
1278
+ _(aten, relu6_) \
1279
+ _(aten, relu_) \
1280
+ _(aten, remainder) \
1281
+ _(aten, remainder_) \
1282
+ _(aten, rename) \
1283
+ _(aten, rename_) \
1284
+ _(aten, renorm) \
1285
+ _(aten, renorm_) \
1286
+ _(aten, repeat) \
1287
+ _(aten, repeat_interleave) \
1288
+ _(aten, replication_pad1d) \
1289
+ _(aten, replication_pad1d_backward) \
1290
+ _(aten, replication_pad2d) \
1291
+ _(aten, replication_pad2d_backward) \
1292
+ _(aten, replication_pad3d) \
1293
+ _(aten, replication_pad3d_backward) \
1294
+ _(aten, requires_grad) \
1295
+ _(aten, requires_grad_) \
1296
+ _(aten, reshape) \
1297
+ _(aten, reshape_as) \
1298
+ _(aten, resize) \
1299
+ _(aten, resize_) \
1300
+ _(aten, resize_as) \
1301
+ _(aten, resize_as_) \
1302
+ _(aten, resize_as_sparse) \
1303
+ _(aten, resize_as_sparse_) \
1304
+ _(aten, resolve_conj) \
1305
+ _(aten, resolve_neg) \
1306
+ _(aten, result_type) \
1307
+ _(aten, retain_grad) \
1308
+ _(aten, retains_grad) \
1309
+ _(aten, rms_norm) \
1310
+ _(aten, rnn_relu) \
1311
+ _(aten, rnn_relu_cell) \
1312
+ _(aten, rnn_tanh) \
1313
+ _(aten, rnn_tanh_cell) \
1314
+ _(aten, roll) \
1315
+ _(aten, rot90) \
1316
+ _(aten, round) \
1317
+ _(aten, round_) \
1318
+ _(aten, row_indices) \
1319
+ _(aten, row_indices_copy) \
1320
+ _(aten, row_stack) \
1321
+ _(aten, rrelu) \
1322
+ _(aten, rrelu_) \
1323
+ _(aten, rrelu_with_noise) \
1324
+ _(aten, rrelu_with_noise_) \
1325
+ _(aten, rrelu_with_noise_backward) \
1326
+ _(aten, rrelu_with_noise_functional) \
1327
+ _(aten, rshift) \
1328
+ _(aten, rsqrt) \
1329
+ _(aten, rsqrt_) \
1330
+ _(aten, rsub) \
1331
+ _(aten, scalar_tensor) \
1332
+ _(aten, scaled_dot_product_attention) \
1333
+ _(aten, scatter) \
1334
+ _(aten, scatter_) \
1335
+ _(aten, scatter_add) \
1336
+ _(aten, scatter_add_) \
1337
+ _(aten, scatter_reduce) \
1338
+ _(aten, scatter_reduce_) \
1339
+ _(aten, searchsorted) \
1340
+ _(aten, segment_reduce) \
1341
+ _(aten, select) \
1342
+ _(aten, select_backward) \
1343
+ _(aten, select_copy) \
1344
+ _(aten, select_scatter) \
1345
+ _(aten, selu) \
1346
+ _(aten, selu_) \
1347
+ _(aten, set) \
1348
+ _(aten, set_) \
1349
+ _(aten, set_data) \
1350
+ _(aten, sgn) \
1351
+ _(aten, sgn_) \
1352
+ _(aten, sigmoid) \
1353
+ _(aten, sigmoid_) \
1354
+ _(aten, sigmoid_backward) \
1355
+ _(aten, sign) \
1356
+ _(aten, sign_) \
1357
+ _(aten, signbit) \
1358
+ _(aten, silu) \
1359
+ _(aten, silu_) \
1360
+ _(aten, silu_backward) \
1361
+ _(aten, sin) \
1362
+ _(aten, sin_) \
1363
+ _(aten, sinc) \
1364
+ _(aten, sinc_) \
1365
+ _(aten, sinh) \
1366
+ _(aten, sinh_) \
1367
+ _(aten, size) \
1368
+ _(aten, slice) \
1369
+ _(aten, slice_backward) \
1370
+ _(aten, slice_copy) \
1371
+ _(aten, slice_inverse) \
1372
+ _(aten, slice_scatter) \
1373
+ _(aten, slogdet) \
1374
+ _(aten, slow_conv3d) \
1375
+ _(aten, slow_conv3d_forward) \
1376
+ _(aten, slow_conv_dilated2d) \
1377
+ _(aten, slow_conv_dilated3d) \
1378
+ _(aten, slow_conv_transpose2d) \
1379
+ _(aten, slow_conv_transpose3d) \
1380
+ _(aten, smm) \
1381
+ _(aten, smooth_l1_loss) \
1382
+ _(aten, smooth_l1_loss_backward) \
1383
+ _(aten, soft_margin_loss) \
1384
+ _(aten, soft_margin_loss_backward) \
1385
+ _(aten, softmax) \
1386
+ _(aten, softplus) \
1387
+ _(aten, softplus_backward) \
1388
+ _(aten, softshrink) \
1389
+ _(aten, softshrink_backward) \
1390
+ _(aten, sort) \
1391
+ _(aten, sparse_bsc_tensor) \
1392
+ _(aten, sparse_bsr_tensor) \
1393
+ _(aten, sparse_compressed_tensor) \
1394
+ _(aten, sparse_coo_tensor) \
1395
+ _(aten, sparse_csc_tensor) \
1396
+ _(aten, sparse_csr_tensor) \
1397
+ _(aten, sparse_dim) \
1398
+ _(aten, sparse_mask) \
1399
+ _(aten, sparse_resize) \
1400
+ _(aten, sparse_resize_) \
1401
+ _(aten, sparse_resize_and_clear) \
1402
+ _(aten, sparse_resize_and_clear_) \
1403
+ _(aten, sparse_sampled_addmm) \
1404
+ _(aten, special_airy_ai) \
1405
+ _(aten, special_bessel_j0) \
1406
+ _(aten, special_bessel_j1) \
1407
+ _(aten, special_bessel_y0) \
1408
+ _(aten, special_bessel_y1) \
1409
+ _(aten, special_chebyshev_polynomial_t) \
1410
+ _(aten, special_chebyshev_polynomial_u) \
1411
+ _(aten, special_chebyshev_polynomial_v) \
1412
+ _(aten, special_chebyshev_polynomial_w) \
1413
+ _(aten, special_digamma) \
1414
+ _(aten, special_entr) \
1415
+ _(aten, special_erf) \
1416
+ _(aten, special_erfc) \
1417
+ _(aten, special_erfcx) \
1418
+ _(aten, special_erfinv) \
1419
+ _(aten, special_exp2) \
1420
+ _(aten, special_expit) \
1421
+ _(aten, special_expm1) \
1422
+ _(aten, special_gammainc) \
1423
+ _(aten, special_gammaincc) \
1424
+ _(aten, special_gammaln) \
1425
+ _(aten, special_hermite_polynomial_h) \
1426
+ _(aten, special_hermite_polynomial_he) \
1427
+ _(aten, special_i0) \
1428
+ _(aten, special_i0e) \
1429
+ _(aten, special_i1) \
1430
+ _(aten, special_i1e) \
1431
+ _(aten, special_laguerre_polynomial_l) \
1432
+ _(aten, special_legendre_polynomial_p) \
1433
+ _(aten, special_log1p) \
1434
+ _(aten, special_log_ndtr) \
1435
+ _(aten, special_log_softmax) \
1436
+ _(aten, special_logit) \
1437
+ _(aten, special_logsumexp) \
1438
+ _(aten, special_modified_bessel_i0) \
1439
+ _(aten, special_modified_bessel_i1) \
1440
+ _(aten, special_modified_bessel_k0) \
1441
+ _(aten, special_modified_bessel_k1) \
1442
+ _(aten, special_multigammaln) \
1443
+ _(aten, special_ndtr) \
1444
+ _(aten, special_ndtri) \
1445
+ _(aten, special_polygamma) \
1446
+ _(aten, special_psi) \
1447
+ _(aten, special_round) \
1448
+ _(aten, special_scaled_modified_bessel_k0) \
1449
+ _(aten, special_scaled_modified_bessel_k1) \
1450
+ _(aten, special_shifted_chebyshev_polynomial_t) \
1451
+ _(aten, special_shifted_chebyshev_polynomial_u) \
1452
+ _(aten, special_shifted_chebyshev_polynomial_v) \
1453
+ _(aten, special_shifted_chebyshev_polynomial_w) \
1454
+ _(aten, special_sinc) \
1455
+ _(aten, special_softmax) \
1456
+ _(aten, special_spherical_bessel_j0) \
1457
+ _(aten, special_xlog1py) \
1458
+ _(aten, special_xlogy) \
1459
+ _(aten, special_zeta) \
1460
+ _(aten, split) \
1461
+ _(aten, split_copy) \
1462
+ _(aten, split_with_sizes) \
1463
+ _(aten, split_with_sizes_copy) \
1464
+ _(aten, sqrt) \
1465
+ _(aten, sqrt_) \
1466
+ _(aten, square) \
1467
+ _(aten, square_) \
1468
+ _(aten, squeeze) \
1469
+ _(aten, squeeze_) \
1470
+ _(aten, squeeze_copy) \
1471
+ _(aten, sspaddmm) \
1472
+ _(aten, stack) \
1473
+ _(aten, std) \
1474
+ _(aten, std_mean) \
1475
+ _(aten, stft) \
1476
+ _(aten, stride) \
1477
+ _(aten, sub) \
1478
+ _(aten, sub_) \
1479
+ _(aten, subtract) \
1480
+ _(aten, subtract_) \
1481
+ _(aten, sum) \
1482
+ _(aten, sum_to_size) \
1483
+ _(aten, svd) \
1484
+ _(aten, swapaxes) \
1485
+ _(aten, swapaxes_) \
1486
+ _(aten, swapdims) \
1487
+ _(aten, swapdims_) \
1488
+ _(aten, sym_constrain_range) \
1489
+ _(aten, sym_constrain_range_for_size) \
1490
+ _(aten, sym_is_contiguous) \
1491
+ _(aten, sym_numel) \
1492
+ _(aten, sym_size) \
1493
+ _(aten, sym_storage_offset) \
1494
+ _(aten, sym_stride) \
1495
+ _(aten, t) \
1496
+ _(aten, t_) \
1497
+ _(aten, t_copy) \
1498
+ _(aten, take) \
1499
+ _(aten, take_along_dim) \
1500
+ _(aten, tan) \
1501
+ _(aten, tan_) \
1502
+ _(aten, tanh) \
1503
+ _(aten, tanh_) \
1504
+ _(aten, tanh_backward) \
1505
+ _(aten, tensor_split) \
1506
+ _(aten, tensordot) \
1507
+ _(aten, thnn_conv2d) \
1508
+ _(aten, threshold) \
1509
+ _(aten, threshold_) \
1510
+ _(aten, threshold_backward) \
1511
+ _(aten, tile) \
1512
+ _(aten, to) \
1513
+ _(aten, to_dense) \
1514
+ _(aten, to_dense_backward) \
1515
+ _(aten, to_mkldnn) \
1516
+ _(aten, to_mkldnn_backward) \
1517
+ _(aten, to_padded_tensor) \
1518
+ _(aten, to_sparse) \
1519
+ _(aten, to_sparse_bsc) \
1520
+ _(aten, to_sparse_bsr) \
1521
+ _(aten, to_sparse_csc) \
1522
+ _(aten, to_sparse_csr) \
1523
+ _(aten, topk) \
1524
+ _(aten, trace) \
1525
+ _(aten, trace_backward) \
1526
+ _(aten, transpose) \
1527
+ _(aten, transpose_) \
1528
+ _(aten, transpose_copy) \
1529
+ _(aten, trapezoid) \
1530
+ _(aten, trapz) \
1531
+ _(aten, triangular_solve) \
1532
+ _(aten, tril) \
1533
+ _(aten, tril_) \
1534
+ _(aten, tril_indices) \
1535
+ _(aten, triplet_margin_loss) \
1536
+ _(aten, triu) \
1537
+ _(aten, triu_) \
1538
+ _(aten, triu_indices) \
1539
+ _(aten, true_divide) \
1540
+ _(aten, true_divide_) \
1541
+ _(aten, trunc) \
1542
+ _(aten, trunc_) \
1543
+ _(aten, type_as) \
1544
+ _(aten, unbind) \
1545
+ _(aten, unbind_copy) \
1546
+ _(aten, unflatten) \
1547
+ _(aten, unflatten_dense_tensors) \
1548
+ _(aten, unfold) \
1549
+ _(aten, unfold_backward) \
1550
+ _(aten, unfold_copy) \
1551
+ _(aten, uniform) \
1552
+ _(aten, uniform_) \
1553
+ _(aten, unique_consecutive) \
1554
+ _(aten, unique_dim) \
1555
+ _(aten, unique_dim_consecutive) \
1556
+ _(aten, unsafe_chunk) \
1557
+ _(aten, unsafe_split) \
1558
+ _(aten, unsafe_split_with_sizes) \
1559
+ _(aten, unsqueeze) \
1560
+ _(aten, unsqueeze_) \
1561
+ _(aten, unsqueeze_copy) \
1562
+ _(aten, upsample_bicubic2d) \
1563
+ _(aten, upsample_bicubic2d_backward) \
1564
+ _(aten, upsample_bilinear2d) \
1565
+ _(aten, upsample_bilinear2d_backward) \
1566
+ _(aten, upsample_linear1d) \
1567
+ _(aten, upsample_linear1d_backward) \
1568
+ _(aten, upsample_nearest1d) \
1569
+ _(aten, upsample_nearest1d_backward) \
1570
+ _(aten, upsample_nearest2d) \
1571
+ _(aten, upsample_nearest2d_backward) \
1572
+ _(aten, upsample_nearest3d) \
1573
+ _(aten, upsample_nearest3d_backward) \
1574
+ _(aten, upsample_trilinear3d) \
1575
+ _(aten, upsample_trilinear3d_backward) \
1576
+ _(aten, value_selecting_reduction_backward) \
1577
+ _(aten, values) \
1578
+ _(aten, values_copy) \
1579
+ _(aten, vander) \
1580
+ _(aten, var) \
1581
+ _(aten, var_mean) \
1582
+ _(aten, vdot) \
1583
+ _(aten, view) \
1584
+ _(aten, view_as) \
1585
+ _(aten, view_as_complex) \
1586
+ _(aten, view_as_complex_copy) \
1587
+ _(aten, view_as_real) \
1588
+ _(aten, view_as_real_copy) \
1589
+ _(aten, view_copy) \
1590
+ _(aten, vsplit) \
1591
+ _(aten, vstack) \
1592
+ _(aten, where) \
1593
+ _(aten, xlogy) \
1594
+ _(aten, xlogy_) \
1595
+ _(aten, zero) \
1596
+ _(aten, zero_) \
1597
+ _(aten, zeros) \
1598
+ _(aten, zeros_like)
1599
+
1600
+ #define FORALL_ATTR_BASE_SYMBOLS(_) \
1601
+ _(attr, A) \
1602
+ _(attr, B) \
1603
+ _(attr, C) \
1604
+ _(attr, H) \
1605
+ _(attr, HxW) \
1606
+ _(attr, K) \
1607
+ _(attr, L) \
1608
+ _(attr, LD) \
1609
+ _(attr, LU) \
1610
+ _(attr, LU_data) \
1611
+ _(attr, LU_pivots) \
1612
+ _(attr, M) \
1613
+ _(attr, N) \
1614
+ _(attr, P) \
1615
+ _(attr, Q) \
1616
+ _(attr, R) \
1617
+ _(attr, S) \
1618
+ _(attr, U) \
1619
+ _(attr, UPLO) \
1620
+ _(attr, V) \
1621
+ _(attr, Vh) \
1622
+ _(attr, W) \
1623
+ _(attr, X) \
1624
+ _(attr, a) \
1625
+ _(attr, abs) \
1626
+ _(attr, accumulate) \
1627
+ _(attr, accumulate_matches) \
1628
+ _(attr, activation) \
1629
+ _(attr, addends) \
1630
+ _(attr, adjoint) \
1631
+ _(attr, alg_id) \
1632
+ _(attr, algorithm) \
1633
+ _(attr, alibi_slopes) \
1634
+ _(attr, align_corners) \
1635
+ _(attr, align_to_window) \
1636
+ _(attr, allow_tf32) \
1637
+ _(attr, alpha) \
1638
+ _(attr, amsgrad) \
1639
+ _(attr, anchor) \
1640
+ _(attr, angle) \
1641
+ _(attr, any) \
1642
+ _(attr, api_name) \
1643
+ _(attr, append) \
1644
+ _(attr, approximate) \
1645
+ _(attr, arg1) \
1646
+ _(attr, arg2) \
1647
+ _(attr, arg3) \
1648
+ _(attr, arg_out) \
1649
+ _(attr, assert_msg) \
1650
+ _(attr, assume_unique) \
1651
+ _(attr, atol) \
1652
+ _(attr, attn_bias) \
1653
+ _(attr, attn_mask) \
1654
+ _(attr, average_attn_weights) \
1655
+ _(attr, averaging_const) \
1656
+ _(attr, aweights) \
1657
+ _(attr, axis) \
1658
+ _(attr, axis0) \
1659
+ _(attr, axis1) \
1660
+ _(attr, b) \
1661
+ _(attr, b_hh) \
1662
+ _(attr, b_ih) \
1663
+ _(attr, bag_size) \
1664
+ _(attr, base) \
1665
+ _(attr, batch1) \
1666
+ _(attr, batch2) \
1667
+ _(attr, batch_dim) \
1668
+ _(attr, batch_first) \
1669
+ _(attr, batch_size) \
1670
+ _(attr, batch_sizes) \
1671
+ _(attr, benchmark) \
1672
+ _(attr, beta) \
1673
+ _(attr, beta1) \
1674
+ _(attr, beta2) \
1675
+ _(attr, bias) \
1676
+ _(attr, bias_defined) \
1677
+ _(attr, bias_g) \
1678
+ _(attr, bias_requires_grad) \
1679
+ _(attr, bias_sizes) \
1680
+ _(attr, bidirectional) \
1681
+ _(attr, bin_edges) \
1682
+ _(attr, bins) \
1683
+ _(attr, bit_width) \
1684
+ _(attr, blank) \
1685
+ _(attr, block_size) \
1686
+ _(attr, blocksize) \
1687
+ _(attr, boundaries) \
1688
+ _(attr, buffer) \
1689
+ _(attr, ccol_indices) \
1690
+ _(attr, cdim) \
1691
+ _(attr, cdist) \
1692
+ _(attr, ceil_mode) \
1693
+ _(attr, cell_state_fwd) \
1694
+ _(attr, center) \
1695
+ _(attr, ch_axis) \
1696
+ _(attr, check_errors) \
1697
+ _(attr, check_pinning) \
1698
+ _(attr, chunks) \
1699
+ _(attr, coalesced) \
1700
+ _(attr, coefficients) \
1701
+ _(attr, col) \
1702
+ _(attr, col_indices) \
1703
+ _(attr, col_offsets) \
1704
+ _(attr, col_offsets_hh) \
1705
+ _(attr, col_offsets_ih) \
1706
+ _(attr, compressed_A) \
1707
+ _(attr, compressed_idx) \
1708
+ _(attr, compressed_indices) \
1709
+ _(attr, compressed_indices_dtype) \
1710
+ _(attr, compute_log_sumexp) \
1711
+ _(attr, compute_mode) \
1712
+ _(attr, compute_uv) \
1713
+ _(attr, compute_v) \
1714
+ _(attr, condition) \
1715
+ _(attr, contraction_dim) \
1716
+ _(attr, copy) \
1717
+ _(attr, correction) \
1718
+ _(attr, count) \
1719
+ _(attr, count_include_pad) \
1720
+ _(attr, counts) \
1721
+ _(attr, cpu_dtype) \
1722
+ _(attr, cpu_enabled) \
1723
+ _(attr, cpu_nested_shape_example) \
1724
+ _(attr, create_graph) \
1725
+ _(attr, crow_indices) \
1726
+ _(attr, cu_seqlens_k) \
1727
+ _(attr, cu_seqlens_q) \
1728
+ _(attr, cuda_dtype) \
1729
+ _(attr, cuda_enabled) \
1730
+ _(attr, cudnn_enable) \
1731
+ _(attr, cudnn_enabled) \
1732
+ _(attr, cum_seq_k) \
1733
+ _(attr, cum_seq_q) \
1734
+ _(attr, custom_mask_type) \
1735
+ _(attr, cx) \
1736
+ _(attr, cx_) \
1737
+ _(attr, cx_tmp) \
1738
+ _(attr, cy) \
1739
+ _(attr, cy_) \
1740
+ _(attr, d) \
1741
+ _(attr, dampening) \
1742
+ _(attr, data) \
1743
+ _(attr, decimals) \
1744
+ _(attr, delta) \
1745
+ _(attr, dense) \
1746
+ _(attr, dense_B) \
1747
+ _(attr, dense_dim) \
1748
+ _(attr, density) \
1749
+ _(attr, dep_token) \
1750
+ _(attr, descending) \
1751
+ _(attr, destination) \
1752
+ _(attr, deterministic) \
1753
+ _(attr, device) \
1754
+ _(attr, device_index) \
1755
+ _(attr, dgrad_glu) \
1756
+ _(attr, diagonal) \
1757
+ _(attr, diagonals) \
1758
+ _(attr, dilation) \
1759
+ _(attr, dim) \
1760
+ _(attr, dim0) \
1761
+ _(attr, dim1) \
1762
+ _(attr, dim2) \
1763
+ _(attr, dimension) \
1764
+ _(attr, dims) \
1765
+ _(attr, dims_other) \
1766
+ _(attr, dims_self) \
1767
+ _(attr, divisor_override) \
1768
+ _(attr, downscale_factor) \
1769
+ _(attr, driver) \
1770
+ _(attr, dropout) \
1771
+ _(attr, dropout_mask) \
1772
+ _(attr, dropout_p) \
1773
+ _(attr, dropout_seed) \
1774
+ _(attr, dropout_state) \
1775
+ _(attr, dst) \
1776
+ _(attr, dtype) \
1777
+ _(attr, dual) \
1778
+ _(attr, dummy) \
1779
+ _(attr, dx) \
1780
+ _(attr, edge_order) \
1781
+ _(attr, eigenvalues) \
1782
+ _(attr, eigenvectors) \
1783
+ _(attr, eigvals) \
1784
+ _(attr, eigvecs) \
1785
+ _(attr, element) \
1786
+ _(attr, elements) \
1787
+ _(attr, ellipsis_idx) \
1788
+ _(attr, embed_dim) \
1789
+ _(attr, enable_gqa) \
1790
+ _(attr, end) \
1791
+ _(attr, end_dim) \
1792
+ _(attr, eps) \
1793
+ _(attr, epsilon) \
1794
+ _(attr, equal_nan) \
1795
+ _(attr, equation) \
1796
+ _(attr, exp_avg_sqs) \
1797
+ _(attr, exp_avgs) \
1798
+ _(attr, expand1) \
1799
+ _(attr, expand2) \
1800
+ _(attr, expand3) \
1801
+ _(attr, exponent) \
1802
+ _(attr, exponential_average_factor) \
1803
+ _(attr, fake_quant_enabled) \
1804
+ _(attr, fake_quant_on) \
1805
+ _(attr, ffn_bias_1) \
1806
+ _(attr, ffn_bias_2) \
1807
+ _(attr, ffn_weight_1) \
1808
+ _(attr, ffn_weight_2) \
1809
+ _(attr, filename) \
1810
+ _(attr, fill) \
1811
+ _(attr, fill_value) \
1812
+ _(attr, flat) \
1813
+ _(attr, forward) \
1814
+ _(attr, found_inf) \
1815
+ _(attr, from) \
1816
+ _(attr, from_) \
1817
+ _(attr, full) \
1818
+ _(attr, full_matrices) \
1819
+ _(attr, fuse_transform_0213) \
1820
+ _(attr, fweights) \
1821
+ _(attr, g) \
1822
+ _(attr, gO) \
1823
+ _(attr, generator) \
1824
+ _(attr, ggI) \
1825
+ _(attr, ggW) \
1826
+ _(attr, ggb) \
1827
+ _(attr, glu) \
1828
+ _(attr, grad) \
1829
+ _(attr, grad_bias) \
1830
+ _(attr, grad_cy) \
1831
+ _(attr, grad_factor) \
1832
+ _(attr, grad_glu) \
1833
+ _(attr, grad_hy) \
1834
+ _(attr, grad_in) \
1835
+ _(attr, grad_input) \
1836
+ _(attr, grad_input_mask) \
1837
+ _(attr, grad_out) \
1838
+ _(attr, grad_out_) \
1839
+ _(attr, grad_output) \
1840
+ _(attr, grad_scale) \
1841
+ _(attr, grad_w) \
1842
+ _(attr, grad_weight) \
1843
+ _(attr, grad_x) \
1844
+ _(attr, grad_y) \
1845
+ _(attr, gradient) \
1846
+ _(attr, grads) \
1847
+ _(attr, grid) \
1848
+ _(attr, group) \
1849
+ _(attr, groups) \
1850
+ _(attr, growth_interval) \
1851
+ _(attr, growth_tracker) \
1852
+ _(attr, half_to_float) \
1853
+ _(attr, has_bias) \
1854
+ _(attr, has_biases) \
1855
+ _(attr, hermitian) \
1856
+ _(attr, hidden_bias) \
1857
+ _(attr, hidden_gates) \
1858
+ _(attr, hidden_size) \
1859
+ _(attr, high) \
1860
+ _(attr, hist) \
1861
+ _(attr, hop_length) \
1862
+ _(attr, hx) \
1863
+ _(attr, hx_) \
1864
+ _(attr, hy_) \
1865
+ _(attr, i1) \
1866
+ _(attr, i2) \
1867
+ _(attr, i3) \
1868
+ _(attr, ignore_index) \
1869
+ _(attr, imag) \
1870
+ _(attr, impl_index) \
1871
+ _(attr, implicit) \
1872
+ _(attr, in_features) \
1873
+ _(attr, include_last_offset) \
1874
+ _(attr, include_self) \
1875
+ _(attr, increasing) \
1876
+ _(attr, ind) \
1877
+ _(attr, index) \
1878
+ _(attr, index_dtype) \
1879
+ _(attr, indexing) \
1880
+ _(attr, indices) \
1881
+ _(attr, info) \
1882
+ _(attr, initial) \
1883
+ _(attr, innerKTiles) \
1884
+ _(attr, inp) \
1885
+ _(attr, input) \
1886
+ _(attr, input1) \
1887
+ _(attr, input2) \
1888
+ _(attr, input3) \
1889
+ _(attr, input_bias) \
1890
+ _(attr, input_dtype) \
1891
+ _(attr, input_g) \
1892
+ _(attr, input_gates) \
1893
+ _(attr, input_lengths) \
1894
+ _(attr, input_scale) \
1895
+ _(attr, input_size) \
1896
+ _(attr, input_sizes) \
1897
+ _(attr, input_zero_point) \
1898
+ _(attr, inputs) \
1899
+ _(attr, interpolation) \
1900
+ _(attr, interpolation_mode) \
1901
+ _(attr, inv_scale) \
1902
+ _(attr, inverse) \
1903
+ _(attr, invert) \
1904
+ _(attr, invstd) \
1905
+ _(attr, is_causal) \
1906
+ _(attr, is_coalesced) \
1907
+ _(attr, is_crow) \
1908
+ _(attr, is_first_step) \
1909
+ _(attr, is_matrix) \
1910
+ _(attr, is_result) \
1911
+ _(attr, is_target) \
1912
+ _(attr, k) \
1913
+ _(attr, keepdim) \
1914
+ _(attr, kernel_size) \
1915
+ _(attr, key) \
1916
+ _(attr, label_smoothing) \
1917
+ _(attr, lambd) \
1918
+ _(attr, largest) \
1919
+ _(attr, last_dim_size) \
1920
+ _(attr, layersOutputs) \
1921
+ _(attr, layout) \
1922
+ _(attr, left) \
1923
+ _(attr, length) \
1924
+ _(attr, lengths) \
1925
+ _(attr, level) \
1926
+ _(attr, like) \
1927
+ _(attr, list) \
1928
+ _(attr, log_alpha) \
1929
+ _(attr, log_input) \
1930
+ _(attr, log_probs) \
1931
+ _(attr, log_target) \
1932
+ _(attr, logabsdet) \
1933
+ _(attr, logsumexp) \
1934
+ _(attr, low) \
1935
+ _(attr, lower) \
1936
+ _(attr, lr) \
1937
+ _(attr, lr_decay) \
1938
+ _(attr, ltm) \
1939
+ _(attr, m) \
1940
+ _(attr, mantissa) \
1941
+ _(attr, margin) \
1942
+ _(attr, mask) \
1943
+ _(attr, mask_check) \
1944
+ _(attr, mask_type) \
1945
+ _(attr, masked_grad) \
1946
+ _(attr, mat) \
1947
+ _(attr, mat1) \
1948
+ _(attr, mat1_meta) \
1949
+ _(attr, mat2) \
1950
+ _(attr, matrices) \
1951
+ _(attr, max) \
1952
+ _(attr, max_exp_avg_sqs) \
1953
+ _(attr, max_k) \
1954
+ _(attr, max_lengths) \
1955
+ _(attr, max_norm) \
1956
+ _(attr, max_q) \
1957
+ _(attr, max_seqlen) \
1958
+ _(attr, max_seqlen_k) \
1959
+ _(attr, max_seqlen_q) \
1960
+ _(attr, max_size) \
1961
+ _(attr, max_val) \
1962
+ _(attr, max_values) \
1963
+ _(attr, maximize) \
1964
+ _(attr, maximum_indices) \
1965
+ _(attr, maxnorm) \
1966
+ _(attr, mean) \
1967
+ _(attr, median) \
1968
+ _(attr, memory_format) \
1969
+ _(attr, meta) \
1970
+ _(attr, min) \
1971
+ _(attr, min_indices) \
1972
+ _(attr, min_seqlen) \
1973
+ _(attr, min_val) \
1974
+ _(attr, minlength) \
1975
+ _(attr, mode) \
1976
+ _(attr, momentum) \
1977
+ _(attr, momentum_buffer_list) \
1978
+ _(attr, n) \
1979
+ _(attr, n_bins) \
1980
+ _(attr, n_fft) \
1981
+ _(attr, names) \
1982
+ _(attr, nan) \
1983
+ _(attr, need_weights) \
1984
+ _(attr, neg_log_likelihood) \
1985
+ _(attr, negative) \
1986
+ _(attr, negative_slope) \
1987
+ _(attr, neginf) \
1988
+ _(attr, nested_size) \
1989
+ _(attr, nested_strides) \
1990
+ _(attr, nesterov) \
1991
+ _(attr, new_data) \
1992
+ _(attr, nnz) \
1993
+ _(attr, noise) \
1994
+ _(attr, non_blocking) \
1995
+ _(attr, norm) \
1996
+ _(attr, norm_bias_1) \
1997
+ _(attr, norm_bias_2) \
1998
+ _(attr, norm_first) \
1999
+ _(attr, norm_type) \
2000
+ _(attr, norm_weight_1) \
2001
+ _(attr, norm_weight_2) \
2002
+ _(attr, normalization) \
2003
+ _(attr, normalized) \
2004
+ _(attr, normalized_shape) \
2005
+ _(attr, nt_example) \
2006
+ _(attr, num_chunks) \
2007
+ _(attr, num_classes) \
2008
+ _(attr, num_generated) \
2009
+ _(attr, num_groups) \
2010
+ _(attr, num_head) \
2011
+ _(attr, num_heads) \
2012
+ _(attr, num_layers) \
2013
+ _(attr, num_parallel) \
2014
+ _(attr, num_samples) \
2015
+ _(attr, num_splits_key) \
2016
+ _(attr, num_weights) \
2017
+ _(attr, numel) \
2018
+ _(attr, observer_on) \
2019
+ _(attr, offs) \
2020
+ _(attr, offset) \
2021
+ _(attr, offset2bag) \
2022
+ _(attr, offsets) \
2023
+ _(attr, onesided) \
2024
+ _(attr, ord) \
2025
+ _(attr, order) \
2026
+ _(attr, other) \
2027
+ _(attr, out) \
2028
+ _(attr, out0) \
2029
+ _(attr, out1) \
2030
+ _(attr, out2) \
2031
+ _(attr, out3) \
2032
+ _(attr, out4) \
2033
+ _(attr, out5) \
2034
+ _(attr, out6) \
2035
+ _(attr, out_channel) \
2036
+ _(attr, out_dim) \
2037
+ _(attr, out_dtype) \
2038
+ _(attr, out_features) \
2039
+ _(attr, out_int32) \
2040
+ _(attr, outdim) \
2041
+ _(attr, output) \
2042
+ _(attr, output_mask) \
2043
+ _(attr, output_padding) \
2044
+ _(attr, output_scale) \
2045
+ _(attr, output_size) \
2046
+ _(attr, output_zero_point) \
2047
+ _(attr, p) \
2048
+ _(attr, packed) \
2049
+ _(attr, packed_hh) \
2050
+ _(attr, packed_ih) \
2051
+ _(attr, packed_weight) \
2052
+ _(attr, packed_weights) \
2053
+ _(attr, pad) \
2054
+ _(attr, pad_mode) \
2055
+ _(attr, padded) \
2056
+ _(attr, padding) \
2057
+ _(attr, padding_idx) \
2058
+ _(attr, padding_mode) \
2059
+ _(attr, padding_side) \
2060
+ _(attr, padding_value) \
2061
+ _(attr, params) \
2062
+ _(attr, path) \
2063
+ _(attr, pdist) \
2064
+ _(attr, per_row_fake_quant) \
2065
+ _(attr, per_sample_weights) \
2066
+ _(attr, periodic) \
2067
+ _(attr, philox_offset) \
2068
+ _(attr, philox_seed) \
2069
+ _(attr, physical_layout) \
2070
+ _(attr, pin_memory) \
2071
+ _(attr, pivot) \
2072
+ _(attr, pivots) \
2073
+ _(attr, plain_idx) \
2074
+ _(attr, plain_indices) \
2075
+ _(attr, pos_weight) \
2076
+ _(attr, posinf) \
2077
+ _(attr, positive) \
2078
+ _(attr, pow) \
2079
+ _(attr, prepend) \
2080
+ _(attr, primal) \
2081
+ _(attr, prob) \
2082
+ _(attr, proj_bias) \
2083
+ _(attr, proj_size) \
2084
+ _(attr, proj_weight) \
2085
+ _(attr, q) \
2086
+ _(attr, qGroupSize) \
2087
+ _(attr, qScale) \
2088
+ _(attr, qScaleAndZeros) \
2089
+ _(attr, qZeros) \
2090
+ _(attr, qkv) \
2091
+ _(attr, qkv_bias) \
2092
+ _(attr, qkv_weight) \
2093
+ _(attr, qtensor) \
2094
+ _(attr, quant_max) \
2095
+ _(attr, quant_min) \
2096
+ _(attr, quasi) \
2097
+ _(attr, query) \
2098
+ _(attr, r) \
2099
+ _(attr, ragged_idx) \
2100
+ _(attr, random_samples) \
2101
+ _(attr, range) \
2102
+ _(attr, rank) \
2103
+ _(attr, ratio) \
2104
+ _(attr, rcond) \
2105
+ _(attr, real) \
2106
+ _(attr, recipe_a) \
2107
+ _(attr, recipe_b) \
2108
+ _(attr, reduce) \
2109
+ _(attr, reduce_range) \
2110
+ _(attr, reduction) \
2111
+ _(attr, repeats) \
2112
+ _(attr, replacement) \
2113
+ _(attr, requires_grad) \
2114
+ _(attr, reserve) \
2115
+ _(attr, reserveSpace) \
2116
+ _(attr, reservedSpace) \
2117
+ _(attr, residuals) \
2118
+ _(attr, result) \
2119
+ _(attr, retain_graph) \
2120
+ _(attr, return_complex) \
2121
+ _(attr, return_counts) \
2122
+ _(attr, return_debug_mask) \
2123
+ _(attr, return_inverse) \
2124
+ _(attr, reverse) \
2125
+ _(attr, right) \
2126
+ _(attr, rng_state) \
2127
+ _(attr, rounding_mode) \
2128
+ _(attr, row) \
2129
+ _(attr, row_indices) \
2130
+ _(attr, rstd) \
2131
+ _(attr, rtol) \
2132
+ _(attr, running_max) \
2133
+ _(attr, running_mean) \
2134
+ _(attr, running_min) \
2135
+ _(attr, running_var) \
2136
+ _(attr, s) \
2137
+ _(attr, save_invstd) \
2138
+ _(attr, save_mean) \
2139
+ _(attr, save_var) \
2140
+ _(attr, save_var_transform) \
2141
+ _(attr, saved_g) \
2142
+ _(attr, saved_norms) \
2143
+ _(attr, saved_v) \
2144
+ _(attr, scalar) \
2145
+ _(attr, scalar1) \
2146
+ _(attr, scalar2) \
2147
+ _(attr, scalars) \
2148
+ _(attr, scale) \
2149
+ _(attr, scale_a) \
2150
+ _(attr, scale_b) \
2151
+ _(attr, scale_backoff_factor) \
2152
+ _(attr, scale_factors) \
2153
+ _(attr, scale_grad_by_freq) \
2154
+ _(attr, scale_growth_factor) \
2155
+ _(attr, scale_hh) \
2156
+ _(attr, scale_ih) \
2157
+ _(attr, scale_result) \
2158
+ _(attr, scales) \
2159
+ _(attr, scales_d) \
2160
+ _(attr, scales_h) \
2161
+ _(attr, scales_w) \
2162
+ _(attr, scales_zeros) \
2163
+ _(attr, sections) \
2164
+ _(attr, seed) \
2165
+ _(attr, self) \
2166
+ _(attr, self_is_result) \
2167
+ _(attr, self_num_batch_dims) \
2168
+ _(attr, self_or_result) \
2169
+ _(attr, self_sizes) \
2170
+ _(attr, seqlen_k) \
2171
+ _(attr, sequences) \
2172
+ _(attr, seqused_k) \
2173
+ _(attr, shape) \
2174
+ _(attr, shared) \
2175
+ _(attr, shared_storage_dqdkdv) \
2176
+ _(attr, shifts) \
2177
+ _(attr, side) \
2178
+ _(attr, sigma) \
2179
+ _(attr, sign) \
2180
+ _(attr, singular_values) \
2181
+ _(attr, size) \
2182
+ _(attr, sizes) \
2183
+ _(attr, skip_first) \
2184
+ _(attr, sobolstate) \
2185
+ _(attr, solution) \
2186
+ _(attr, some) \
2187
+ _(attr, sorted) \
2188
+ _(attr, sorted_sequence) \
2189
+ _(attr, sorter) \
2190
+ _(attr, source) \
2191
+ _(attr, spacing) \
2192
+ _(attr, sparse) \
2193
+ _(attr, sparse_dim) \
2194
+ _(attr, sparse_grad) \
2195
+ _(attr, split_k) \
2196
+ _(attr, split_k_mode) \
2197
+ _(attr, split_size) \
2198
+ _(attr, split_sizes) \
2199
+ _(attr, src) \
2200
+ _(attr, stable) \
2201
+ _(attr, start) \
2202
+ _(attr, start_dim) \
2203
+ _(attr, state_steps) \
2204
+ _(attr, state_sums) \
2205
+ _(attr, std) \
2206
+ _(attr, step) \
2207
+ _(attr, steps) \
2208
+ _(attr, storage_offset) \
2209
+ _(attr, stride) \
2210
+ _(attr, sum_S) \
2211
+ _(attr, sum_dy) \
2212
+ _(attr, sum_dy_xmu) \
2213
+ _(attr, sumdim) \
2214
+ _(attr, swap) \
2215
+ _(attr, swizzle_a) \
2216
+ _(attr, swizzle_b) \
2217
+ _(attr, symmetric_quant) \
2218
+ _(attr, t) \
2219
+ _(attr, tangent) \
2220
+ _(attr, target) \
2221
+ _(attr, target_lengths) \
2222
+ _(attr, targets) \
2223
+ _(attr, tau) \
2224
+ _(attr, tensor) \
2225
+ _(attr, tensor1) \
2226
+ _(attr, tensor2) \
2227
+ _(attr, tensor_indices_or_sections) \
2228
+ _(attr, tensors) \
2229
+ _(attr, tensors1) \
2230
+ _(attr, test_element) \
2231
+ _(attr, test_elements) \
2232
+ _(attr, the_template) \
2233
+ _(attr, theta) \
2234
+ _(attr, thread_masks) \
2235
+ _(attr, threshold) \
2236
+ _(attr, to) \
2237
+ _(attr, tol) \
2238
+ _(attr, total) \
2239
+ _(attr, total_L) \
2240
+ _(attr, total_length) \
2241
+ _(attr, total_weight) \
2242
+ _(attr, train) \
2243
+ _(attr, training) \
2244
+ _(attr, transpose) \
2245
+ _(attr, transpose_result) \
2246
+ _(attr, transposed) \
2247
+ _(attr, type1) \
2248
+ _(attr, type2) \
2249
+ _(attr, unbiased) \
2250
+ _(attr, unitriangular) \
2251
+ _(attr, unpack_data) \
2252
+ _(attr, unpack_pivots) \
2253
+ _(attr, unroll_dim) \
2254
+ _(attr, unsafe) \
2255
+ _(attr, unused) \
2256
+ _(attr, update) \
2257
+ _(attr, upper) \
2258
+ _(attr, upscale_factor) \
2259
+ _(attr, use_cutlass) \
2260
+ _(attr, use_fast_accum) \
2261
+ _(attr, use_gelu) \
2262
+ _(attr, use_input_stats) \
2263
+ _(attr, v) \
2264
+ _(attr, value) \
2265
+ _(attr, values) \
2266
+ _(attr, var) \
2267
+ _(attr, vec) \
2268
+ _(attr, vec1) \
2269
+ _(attr, vec2) \
2270
+ _(attr, w_hh) \
2271
+ _(attr, w_ih) \
2272
+ _(attr, weight) \
2273
+ _(attr, weight0) \
2274
+ _(attr, weight1) \
2275
+ _(attr, weight2) \
2276
+ _(attr, weight3) \
2277
+ _(attr, weight4) \
2278
+ _(attr, weight_arr) \
2279
+ _(attr, weight_buf) \
2280
+ _(attr, weight_decay) \
2281
+ _(attr, weight_g) \
2282
+ _(attr, weight_scale) \
2283
+ _(attr, weight_stride0) \
2284
+ _(attr, weight_zero_point) \
2285
+ _(attr, weights) \
2286
+ _(attr, win_length) \
2287
+ _(attr, window) \
2288
+ _(attr, window_length) \
2289
+ _(attr, window_size) \
2290
+ _(attr, window_size_left) \
2291
+ _(attr, window_size_right) \
2292
+ _(attr, with_replacement) \
2293
+ _(attr, workspace) \
2294
+ _(attr, wrap) \
2295
+ _(attr, x) \
2296
+ _(attr, x1) \
2297
+ _(attr, x2) \
2298
+ _(attr, y) \
2299
+ _(attr, z) \
2300
+ _(attr, z_state) \
2301
+ _(attr, zero_infinity) \
2302
+ _(attr, zero_point) \
2303
+ _(attr, zero_point_hh) \
2304
+ _(attr, zero_point_ih) \
2305
+ _(attr, zero_points)
2306
+
2307
+ #else
2308
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
2309
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/blob.h ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <type_traits>
5
+
6
+ #include <c10/util/intrusive_ptr.h>
7
+ #include <c10/util/typeid.h>
8
+ #include <c10/macros/Macros.h>
9
+
10
+ namespace caffe2 {
11
+
12
+ class Tensor;
13
+
14
+ /**
15
+ * @brief Blob is a general container that hosts a typed pointer.
16
+ *
17
+ * A Blob hosts a pointer as well as its type, and takes charge of deleting it
18
+ * properly when the blob is deallocated or re-allocated with a new type. A blob
19
+ * could contain anything, although the most common case is to contain a Tensor.
20
+ */
21
+ class TORCH_API Blob final : public c10::intrusive_ptr_target {
22
+ public:
23
+ /**
24
+ * Initializes an empty Blob.
25
+ */
26
+ Blob() noexcept = default;
27
+ ~Blob() override {
28
+ Reset();
29
+ }
30
+
31
+ Blob(Blob&& other) noexcept : Blob() {
32
+ swap(other);
33
+ }
34
+
35
+ Blob& operator=(Blob&& other) noexcept {
36
+ Blob(std::move(other)).swap(*this);
37
+ return *this;
38
+ }
39
+
40
+ /**
41
+ * Checks if the content stored in the blob is of type T.
42
+ */
43
+ template <class T>
44
+ bool IsType() const noexcept {
45
+ return meta_.Match<T>();
46
+ }
47
+
48
+ /**
49
+ * Returns the meta info of the blob.
50
+ */
51
+ const TypeMeta meta() const noexcept {
52
+ return meta_;
53
+ }
54
+
55
+ /**
56
+ * Returns a printable typename of the blob.
57
+ */
58
+ std::string_view TypeName() const noexcept {
59
+ return meta_.name();
60
+ }
61
+
62
+ /**
63
+ * @brief Gets the const reference of the stored object. The code checks if
64
+ * the stored object is of the desired type.
65
+ */
66
+ // TODO(jerryzh): add a Get(c10::DeviceType) function?
67
+ template <class T>
68
+ const T& Get() const {
69
+ TORCH_INTERNAL_ASSERT(
70
+ IsType<T>(),
71
+ "wrong type for the Blob instance. Blob contains ",
72
+ meta_.name(),
73
+ " while caller expects ",
74
+ TypeMeta::TypeName<T>());
75
+ // TODO: after we add Get<Tensor>(c10::DeviceType)
76
+ // and changed all the callsites, we can add
77
+ // a static assert here to enforce T != Tensor
78
+ return *static_cast<const T*>(pointer_);
79
+ }
80
+
81
+ const void* GetRaw() const noexcept {
82
+ return pointer_;
83
+ }
84
+ void* GetRaw() noexcept {
85
+ return pointer_;
86
+ }
87
+
88
+ /**
89
+ * @brief Gets a mutable pointer to the stored object.
90
+ *
91
+ * If the current object is not of the right type, a new object is created
92
+ * and the old object is freed. Note that type T should have a default
93
+ * constructor. Otherwise, create the object yourself first, and use
94
+ * Reset().
95
+ */
96
+ template <class T>
97
+ T* GetMutable() {
98
+ static_assert(
99
+ std::is_default_constructible_v<T>,
100
+ "GetMutable can't be called with non-default-constructible types. "
101
+ "Try using specialized methods");
102
+ if (IsType<T>()) {
103
+ return static_cast<T*>(pointer_);
104
+ } else {
105
+ // TODO Re-enable logging
106
+ // VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<T>();
107
+ return Reset<T>(new T());
108
+ }
109
+ }
110
+
111
+ template <class T>
112
+ T* GetMutableOrNull() {
113
+ if (IsType<T>()) {
114
+ return static_cast<T*>(pointer_);
115
+ } else {
116
+ return nullptr;
117
+ }
118
+ }
119
+
120
+ /**
121
+ * Sets the underlying object to the allocated one. The Blob then takes over
122
+ * the ownership of the passed in pointer. If there is already an object in
123
+ * the Blob, the old object is freed.
124
+ *
125
+ * This is used when the underlying class T does not have a default ctor, or
126
+ * complex initializations needs to be done outside the blob.
127
+ */
128
+ template <class T>
129
+ T* Reset(T* allocated) {
130
+ free_();
131
+ meta_ = TypeMeta::Make<T>();
132
+ pointer_ = static_cast<void*>(allocated);
133
+ has_ownership_ = true;
134
+ return allocated;
135
+ }
136
+
137
+ /**
138
+ * Sets the underlying object to the allocated one, but does not take over
139
+ * the ownership of the passed in pointer. If there is already an object in
140
+ * the Blob, the old object is freed.
141
+ *
142
+ * Unlike Reset, this does not take over the ownership of the pointer and the
143
+ * caller is responsible for making sure that the lifetime of the allocated
144
+ * blob outlasts the lifetime of any access to this blob, until another Reset
145
+ * call is made or the blob is destructed.
146
+ */
147
+ template <class T>
148
+ std::remove_const_t<T>* ShareExternal(
149
+ std::remove_const_t<T>* allocated) {
150
+ return static_cast<T*>(ShareExternal(
151
+ static_cast<void*>(allocated),
152
+ TypeMeta::Make<std::remove_const_t<T>>()));
153
+ }
154
+
155
+ void* ShareExternal(void* allocated, const TypeMeta meta) {
156
+ free_();
157
+ meta_ = meta;
158
+ pointer_ = allocated;
159
+ has_ownership_ = false;
160
+ return allocated;
161
+ }
162
+
163
+ /**
164
+ * Resets the Blob to an empty one.
165
+ */
166
+ void Reset() {
167
+ free_();
168
+ pointer_ = nullptr;
169
+ meta_ = TypeMeta();
170
+ has_ownership_ = false;
171
+ }
172
+
173
+ /**
174
+ * @brief Swaps the underlying storage of two blobs.
175
+ */
176
+ void swap(Blob& rhs) noexcept {
177
+ using std::swap;
178
+ swap(meta_, rhs.meta_);
179
+ swap(pointer_, rhs.pointer_);
180
+ swap(has_ownership_, rhs.has_ownership_);
181
+ }
182
+
183
+ private:
184
+ void free_() {
185
+ if (has_ownership_ && pointer_ != nullptr) {
186
+ (*meta_.deleteFn())(pointer_);
187
+ }
188
+ }
189
+
190
+ TypeMeta meta_;
191
+ void* pointer_{nullptr};
192
+ bool has_ownership_{false};
193
+
194
+ C10_DISABLE_COPY_AND_ASSIGN(Blob);
195
+ };
196
+
197
+ inline void swap(Blob& lhs, Blob& rhs) noexcept {
198
+ lhs.swap(rhs);
199
+ }
200
+
201
+ inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
202
+ return out << "Blob[" << v.TypeName() << ']';
203
+ }
204
+
205
+ } // namespace caffe2
206
+
207
+ #else
208
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
209
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/builtin_function.h ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/function.h>
5
+ #include <ATen/core/ivalue.h>
6
+ #include <c10/util/Exception.h>
7
+ #include <c10/util/intrusive_ptr.h>
8
+ #include <functional>
9
+ #include <utility>
10
+
11
+ namespace torch::jit {
12
+
13
+ struct BuiltinOpFunction : public Function {
14
+ BuiltinOpFunction(
15
+ c10::QualifiedName qualname,
16
+ c10::FunctionSchema schema,
17
+ std::function<void(Stack&)> callable,
18
+ std::string doc_string = "")
19
+ : name_(std::move(qualname)),
20
+ callable_(std::move(callable)),
21
+ schema_(std::move(schema)),
22
+ doc_string_(std::move(doc_string)) {
23
+ TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
24
+ }
25
+
26
+ std::string_view doc_string() const override {
27
+ return doc_string_;
28
+ }
29
+
30
+ void run(Stack& stack) override {
31
+ callable_(stack);
32
+ }
33
+
34
+ c10::intrusive_ptr<c10::ivalue::Future> runAsync(
35
+ Stack& stack,
36
+ TaskLauncher /* not used */) override {
37
+ run(stack);
38
+ auto res = c10::make_intrusive<c10::ivalue::Future>(stack.front().type());
39
+ res->markCompleted(std::move(stack.front()));
40
+ return res;
41
+ }
42
+
43
+ const c10::QualifiedName& qualname() const override {
44
+ return name_;
45
+ }
46
+
47
+ // if this isn't yet defined, run its method_creator function
48
+ void ensure_defined() override {
49
+ // nop
50
+ }
51
+
52
+ const c10::FunctionSchema& getSchema() const override {
53
+ return schema_;
54
+ }
55
+
56
+ size_t num_inputs() const override {
57
+ return schema_.arguments().size();
58
+ }
59
+
60
+ Function& setSchema(c10::FunctionSchema schema) override {
61
+ schema_ = std::move(schema);
62
+ return *this;
63
+ }
64
+
65
+ bool call(
66
+ Stack& stack,
67
+ std::optional<size_t> /*unused*/,
68
+ c10::function_ref<void(const Code&)> /*unused*/) override {
69
+ run(stack);
70
+ return false;
71
+ }
72
+
73
+ bool call(Stack& stack, c10::function_ref<void(const mobile::Code&)> /*unused*/)
74
+ override {
75
+ run(stack);
76
+ return false;
77
+ }
78
+
79
+ ~BuiltinOpFunction() override = default;
80
+
81
+ private:
82
+ c10::QualifiedName name_;
83
+
84
+ std::function<void(Stack&)> callable_;
85
+
86
+ c10::FunctionSchema schema_;
87
+
88
+ std::string doc_string_;
89
+ };
90
+
91
+ } // namespace torch::jit
92
+
93
+ #else
94
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
95
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/core/class_type.h ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <memory>
5
+
6
+ #include <ATen/core/ivalue.h>
7
+ #include <ATen/core/jit_type_base.h>
8
+ #include <optional>
9
+
10
+
11
+ namespace torch::jit {
12
+ struct CompilationUnit;
13
+ struct Function;
14
+ } // namespace torch::jit
15
+
16
+
17
+ namespace c10 {
18
+
19
+ struct FunctionSchema;
20
+
21
+ // This enumerator represents the 'kind' of an attribute - a buffer, a parameter, or neither.
22
+ // This state is mutually exclusive. Buffers and Parameters can only appear on modules.
23
+ enum class AttributeKind {
24
+ BUFFER,
25
+ PARAMETER,
26
+ REGULAR_ATTRIBUTE
27
+ };
28
+
29
+ // This structure represents all notional booking entities in a class attribute: name, kind (see: AttributeKind), and type (see: TypePtr).
30
+ // Note: This structure does not represent the value of the attribute.
31
+ struct TORCH_API ClassAttribute {
32
+ public:
33
+ ClassAttribute(AttributeKind kind,
34
+ TypePtr attributeType,
35
+ std::string attributeName) :
36
+ kind_(kind),
37
+ attributeType_(std::move(attributeType)),
38
+ attributeName_(std::move(attributeName)) {}
39
+
40
+ AttributeKind getKind() const {
41
+ return kind_;
42
+ }
43
+
44
+ const TypePtr& getType() const {
45
+ return attributeType_;
46
+ }
47
+
48
+ const std::string& getName() const {
49
+ return attributeName_;
50
+ }
51
+
52
+ private:
53
+ AttributeKind kind_;
54
+ TypePtr attributeType_;
55
+ std::string attributeName_;
56
+ };
57
+
58
+ /**
59
+ * User Defined Types
60
+ */
61
+
62
+ struct ClassType;
63
+ using ClassTypePtr = std::shared_ptr<ClassType>;
64
+ using ::torch::jit::CompilationUnit;
65
+
66
+ // This represents a class in TorchScript.
67
+ struct TORCH_API ClassType : public NamedType {
68
+ // This represents an attribute of a class; a name associated with an attribute, and a
69
+ // getter and (optional) setter for that attribute.
70
+ struct Property {
71
+ std::string name;
72
+ torch::jit::Function* getter;
73
+ torch::jit::Function* setter;
74
+ };
75
+
76
+ // Create a class type with name `name` and its methods stored in `cu`.
77
+ static ClassTypePtr create(
78
+ std::optional<QualifiedName> qualifiedName,
79
+ std::weak_ptr<CompilationUnit> cu,
80
+ bool is_module = false,
81
+ std::string doc_string = "",
82
+ std::vector<std::string> unresolved_class_attributes = {});
83
+
84
+ bool equals(const Type& rhs) const override {
85
+ if (this == &rhs) {
86
+ return true;
87
+ }
88
+ if (auto user_rhs = rhs.castRaw<ClassType>()) {
89
+ const auto& lhs_name = name();
90
+ const auto& rhs_name = user_rhs->name();
91
+ return lhs_name.has_value() && lhs_name == rhs_name &&
92
+ this->compilation_unit() == user_rhs->compilation_unit();
93
+ }
94
+ return false;
95
+ }
96
+
97
+ std::string str() const override {
98
+ return annotation_str();
99
+ }
100
+
101
+ std::string repr_str() const override {
102
+ std::stringstream ss;
103
+ ss << str()
104
+ << " (of Python compilation unit at: " << compilation_unit().get() << ')';
105
+ return ss.str();
106
+ }
107
+
108
+ const std::vector<torch::jit::Function*>& methods() const;
109
+
110
+ TypePtr findAttribute(const std::string& name) const {
111
+ size_t pos = 0;
112
+ for (const auto& attr : attributes_) {
113
+ if (name == attr.getName()) {
114
+ break;
115
+ }
116
+ ++pos;
117
+ }
118
+
119
+ if (pos >= attributes_.size()) {
120
+ return nullptr;
121
+ }
122
+ return attributes_[pos].getType();
123
+ }
124
+
125
+ const TypePtr& getAttribute(const std::string& name) const {
126
+ auto slot = findAttributeSlot(name);
127
+ TORCH_CHECK(
128
+ slot,
129
+ repr_str(),
130
+ " does not have an attribute with name '",
131
+ name,
132
+ "'");
133
+ return attributes_[*slot].getType();
134
+ }
135
+
136
+ size_t numAttributes() const {
137
+ return attributes_.size();
138
+ }
139
+
140
+ const TypePtr& getAttribute(size_t slot) const {
141
+ AT_ASSERT(slot < attributes_.size());
142
+ return attributes_.at(slot).getType();
143
+ }
144
+
145
+ const std::string getAttributeName(size_t slot) const {
146
+ AT_ASSERT(slot < attributes_.size());
147
+ return attributes_[slot].getName();
148
+ }
149
+
150
+ void checkNotExist(const std::string& name, const std::string& what) const;
151
+
152
+ // Attributes are stored in a specific slot at runtime for efficiency.
153
+ // When emitting instructions we specify the slot so that attribute access is
154
+ // a constant lookup
155
+ std::optional<size_t> findAttributeSlot(const std::string& name) const {
156
+ size_t slot = 0;
157
+ for (const auto& attr : attributes_) {
158
+ if (name == attr.getName()) {
159
+ return slot;
160
+ }
161
+ slot++;
162
+ }
163
+ return std::nullopt;
164
+ }
165
+ size_t getAttributeSlot(const std::string& name) const {
166
+ if (auto r = findAttributeSlot(name)) {
167
+ return *r;
168
+ }
169
+ TORCH_CHECK(
170
+ false,
171
+ repr_str(),
172
+ " does not have an attribute with name '",
173
+ name,
174
+ "'");
175
+ }
176
+
177
+ bool hasAttribute(const std::string& name) const {
178
+ return std::find_if(
179
+ attributes_.cbegin(),
180
+ attributes_.cend(),
181
+ [&](const ClassAttribute& attr) { return attr.getName() == name; }) !=
182
+ attributes_.cend();
183
+ }
184
+
185
+ bool isUnresolvedClassAttribute(const std::string& name) const;
186
+
187
+ at::ArrayRef<TypePtr> containedTypes() const override {
188
+ return attributeTypes_;
189
+ }
190
+
191
+ size_t addAttribute(
192
+ const std::string& name,
193
+ TypePtr type,
194
+ bool is_parameter = false,
195
+ bool is_buffer = false);
196
+
197
+ // [Internal Only] Remove attribute from the ClassType,
198
+ // caller is responsible to make sure the modification is safe:
199
+ // it is unsafe to having existing allocations
200
+ // of this object around anymore, and any code that works on
201
+ // the attribute is now invalid. Only newly created code is
202
+ // valid again.
203
+ void unsafeRemoveAttribute(const std::string& name);
204
+
205
+ // [Internal Only] Change the type of an attribute of the ClassType,
206
+ // The caller is responsible to make sure the modification is safe:
207
+ // it is unsafe to maintain uses of the old type of the attribute,
208
+ // and any code that works on the attribute is now invalid.
209
+ // Only newly created code is valid again.
210
+ void unsafeChangeAttributeType(const std::string& name, const TypePtr& new_ty);
211
+
212
+ // Add attribute \p NAME if it doesn't exist or verify that it has a
213
+ // compatible type otherwise.
214
+ size_t addOrCheckAttribute(
215
+ const std::string& name,
216
+ TypePtr ty,
217
+ bool is_parameter = false,
218
+ bool is_buffer = false) {
219
+ auto slot_idx = findAttributeSlot(name);
220
+ if (!slot_idx) {
221
+ return addAttribute(name, std::move(ty), is_parameter, is_buffer);
222
+ }
223
+
224
+ TORCH_CHECK(
225
+ is_parameter == this->is_parameter(*slot_idx),
226
+ "Parameter field mismatch for the field '",
227
+ name,
228
+ "'");
229
+ const TypePtr& atype = getAttribute(*slot_idx);
230
+ TORCH_CHECK(
231
+ ty->isSubtypeOf(*atype),
232
+ ty->repr_str(),
233
+ " is not compatible with the type ",
234
+ atype->repr_str(),
235
+ " for the field '",
236
+ name,
237
+ "'");
238
+ return *slot_idx;
239
+ }
240
+
241
+ // Get the property with the given \p name, if it exists on the class.
242
+ std::optional<ClassType::Property> getProperty(const std::string& name);
243
+ // Add a property named \p name with \p getter and \p setter as its getter and setter.
244
+ void addProperty(const std::string& name, torch::jit::Function* getter, torch::jit::Function* setter);
245
+ // Get a list of all properties.
246
+ const std::vector<Property>& properties() const {
247
+ return properties_;
248
+ }
249
+
250
+ bool hasConstant(const std::string& name) const {
251
+ return std::find_if(
252
+ constantNames_.cbegin(),
253
+ constantNames_.cend(),
254
+ [&](const std::string& constant) { return constant == name; }) !=
255
+ constantNames_.cend();
256
+ }
257
+
258
+ size_t addConstant(const std::string& name, const IValue& value);
259
+
260
+ std::optional<size_t> findConstantSlot(const std::string& name) const;
261
+
262
+ size_t getConstantSlot(const std::string& name) const {
263
+ if (auto r = findConstantSlot(name)) {
264
+ return *r;
265
+ }
266
+ TORCH_CHECK(
267
+ false,
268
+ repr_str(),
269
+ " does not have constant field with the name '",
270
+ name,
271
+ "'");
272
+ }
273
+
274
+ const std::string& getConstantName(size_t slot) const;
275
+
276
+ const std::string& doc_string() const {
277
+ return doc_string_;
278
+ }
279
+
280
+ IValue getConstant(const std::string& name) const;
281
+
282
+ IValue getConstant(size_t slot) const;
283
+
284
+ std::optional<IValue> findConstant(const std::string& name) const;
285
+
286
+ size_t numConstants() const;
287
+
288
+ at::ArrayRef<std::string> constantNames() const {
289
+ return constantNames_;
290
+ }
291
+
292
+ at::ArrayRef<IValue> constantValues() const;
293
+
294
+ // [Internal Only] Remove constant from the ClassType
295
+ // caller is responsible to make sure the modification is safe:
296
+ // it is unsafe to having existing allocations
297
+ // of this object around anymore, and any code that works on
298
+ // the attribute is now invalid. Only newly created code is
299
+ // valid again.
300
+ void unsafeRemoveConstant(const std::string& name);
301
+
302
+ TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
303
+ auto ptr = ClassType::create(name(), compilation_unit_, is_module());
304
+ AT_ASSERT(numAttributes() == contained_types.size());
305
+ for(size_t i = 0; i < attributes_.size(); ++i) {
306
+ AT_ASSERT(attributes_[i].getType()->isSubtypeOf(*contained_types[i]));
307
+ ptr->addAttribute(attributes_[i].getName(), std::move(contained_types[i]));
308
+ }
309
+ // Copy methods over
310
+ for (const auto& method : methods()) {
311
+ ptr->addMethod(method);
312
+ }
313
+ return ptr;
314
+ }
315
+
316
+ bool is_module() const override {
317
+ return isModule_;
318
+ }
319
+
320
+ const std::vector<ClassAttribute>& getAttributes() const {
321
+ return attributes_;
322
+ }
323
+
324
+ bool is_parameter(size_t slot) const {
325
+ TORCH_INTERNAL_ASSERT(
326
+ is_module(), "asking for parameterSlots of non-Module");
327
+ return attributes_.at(slot).getKind() == AttributeKind::PARAMETER;
328
+ }
329
+
330
+ bool is_buffer(size_t slot) const {
331
+ TORCH_INTERNAL_ASSERT(
332
+ is_module(), "asking for bufferWrittenSlots of non-Module");
333
+ return attributes_.at(slot).getKind() == AttributeKind::BUFFER;
334
+ }
335
+
336
+ void addForwardPreHook(torch::jit::Function* pre_hook_ptr);
337
+ void addForwardHook(torch::jit::Function* hook_ptr);
338
+ torch::jit::Function* findForwardPreHook(const std::string& name) const;
339
+ torch::jit::Function* findForwardHook(const std::string& name) const;
340
+ const std::vector<torch::jit::Function*>& getForwardHooks() const;
341
+ const std::vector<torch::jit::Function*>& getForwardPreHooks() const;
342
+
343
+ void checkForwardPreHookSchema(
344
+ size_t pre_hook_idx,
345
+ const FunctionSchema& pre_hook_schema) const;
346
+ void checkForwardHookSchema(
347
+ size_t hook_idx,
348
+ const FunctionSchema& hook_schema) const;
349
+
350
+ void addMethod(torch::jit::Function* method);
351
+ torch::jit::Function* findMethod(const std::string& name) const;
352
+ torch::jit::Function& getMethod(const std::string& name) const;
353
+ torch::jit::Function* findHook(const std::string& name) const;
354
+ torch::jit::Function& getHook(const std::string& name) const;
355
+ bool hasMethod(const std::string& name) const;
356
+
357
+ torch::jit::Function* findStaticMethod(const std::string& name) const;
358
+ void addStaticMethod(torch::jit::Function* method);
359
+
360
+ // [Internal Only] Remove method from the ClassType
361
+ // caller is responsible to make sure the modification is safe:
362
+ // it is unsafe to having existing allocations
363
+ // of this object around anymore, and any code that works on
364
+ // the attribute is now invalid. Only newly created code is
365
+ // valid again.
366
+ // Note this method is intended for freezing only.
367
+ void unsafeRemoveMethod(const std::string& name);
368
+
369
+ std::shared_ptr<CompilationUnit> compilation_unit();
370
+
371
+ std::shared_ptr<const CompilationUnit> compilation_unit() const;
372
+
373
+ // generate a refined version of this class.
374
+ // It has the same name but the slot Types are subtypes of
375
+ // the original slots. It is only valid to refine a class type in a context
376
+ // where it is know that there are not assignments to the objects slots
377
+ // that would invalidate the refinement.
378
+ // These variants are not registered in the global class table.
379
+ ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const;
380
+
381
+ bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
382
+
383
+ static const TypeKind Kind = TypeKind::ClassType;
384
+
385
+ private:
386
+ ClassType(
387
+ std::optional<QualifiedName> name,
388
+ std::weak_ptr<CompilationUnit> cu,
389
+ bool is_module = false,
390
+ std::string doc_string = "",
391
+ std::vector<std::string> unresolved_class_attributes = {});
392
+
393
+ std::string annotation_str_impl(
394
+ [[maybe_unused]] const TypePrinter& printer = nullptr) const override {
395
+ // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
396
+ return name()->qualifiedName();
397
+ }
398
+
399
+ void addAttribute(ClassAttribute classAttribute);
400
+ std::string getForwardPreHookErrorMessage(size_t pre_hook_idx) const;
401
+ std::string getForwardHookErrorMessage(size_t hook_idx) const;
402
+
403
+ // Mapping of attribute names -> their type.
404
+ // NOTE: this does not contain methods, which are stored in the module
405
+ // TODO: once modules support arbitrary ivalue attributes, we don't need this
406
+ // anymore.
407
+ // TODO: This is better represented as an OrderedDict, but alas it is not yet
408
+ // available from c10
409
+
410
+ // Mapping of constant names -> their value.
411
+ std::vector<std::string> constantNames_;
412
+ std::vector<IValue> constantValues_;
413
+ // Holds method attributes
414
+ std::weak_ptr<CompilationUnit> compilation_unit_;
415
+
416
+ // Holds all attributes, attribute details are found on ClassAttribute
417
+ std::vector<ClassAttribute> attributes_;
418
+ // Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef.
419
+ // Never fill this without using the appropriate provideNewClassAttribute method
420
+ std::vector<TypePtr> attributeTypes_;
421
+
422
+ // List of methods associated with this class.
423
+ std::vector<torch::jit::Function*> methods_;
424
+ std::vector<torch::jit::Function*> staticmethods_;
425
+
426
+ // List of hooks to be run before/after forward.
427
+ std::vector<torch::jit::Function*> forward_hooks_;
428
+ std::vector<torch::jit::Function*> forward_pre_hooks_;
429
+
430
+ // List of properties exposed by this class.
431
+ std::vector<Property> properties_;
432
+
433
+ bool isModule_ = false;
434
+
435
+ // Doc string of class.
436
+ std::string doc_string_;
437
+
438
+ // For error reporting accesses to class level attributes.
439
+ std::vector<std::string> unresolved_class_attributes_;
440
+ };
441
+
442
+ }
443
+
444
+ #else
445
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
446
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)