koichi12 commited on
Commit
07232d7
·
verified ·
1 Parent(s): 8385308

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torch/include/ATen/core/ATenGeneral.h +3 -0
  2. .venv/lib/python3.11/site-packages/torch/include/ATen/core/ATen_fwd.h +46 -0
  3. .venv/lib/python3.11/site-packages/torch/include/ATen/core/Dict_inl.h +209 -0
  4. .venv/lib/python3.11/site-packages/torch/include/ATen/core/List.h +488 -0
  5. .venv/lib/python3.11/site-packages/torch/include/ATen/core/List_inl.h +353 -0
  6. .venv/lib/python3.11/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h +1 -0
  7. .venv/lib/python3.11/site-packages/torch/include/ATen/core/Variadic.h +92 -0
  8. .venv/lib/python3.11/site-packages/torch/include/ATen/core/Vitals.h +91 -0
  9. .venv/lib/python3.11/site-packages/torch/include/ATen/core/aten_interned_strings.h +2264 -0
  10. .venv/lib/python3.11/site-packages/torch/include/ATen/core/blob.h +204 -0
  11. .venv/lib/python3.11/site-packages/torch/include/ATen/core/function.h +114 -0
  12. .venv/lib/python3.11/site-packages/torch/include/ATen/core/functional.h +54 -0
  13. .venv/lib/python3.11/site-packages/torch/include/ATen/core/ivalue_inl.h +2539 -0
  14. .venv/lib/python3.11/site-packages/torch/include/ATen/core/rref_interface.h +40 -0
  15. .venv/lib/python3.11/site-packages/torch/include/ATen/core/stack.h +204 -0
  16. .venv/lib/python3.11/site-packages/torch/include/ATen/core/type_factory.h +108 -0
  17. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_batch_norm_with_update_native.h +26 -0
  18. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_coalesced_native.h +23 -0
  19. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_convert_weight_to_int4pack_cpu_dispatch.h +23 -0
  20. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cudnn_rnn_compositeexplicitautograd_dispatch.h +26 -0
  21. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_dimI.h +26 -0
  22. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_embedding_bag_per_sample_weights_backward_compositeexplicitautograd_dispatch.h +24 -0
  23. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_embedding_bag_sparse_backward.h +47 -0
  24. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_max.h +39 -0
  25. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_round_cuda_dispatch.h +24 -0
  26. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_lu_with_info.h +30 -0
  27. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_scale_compositeexplicitautograd_dispatch.h +24 -0
  28. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_native.h +23 -0
  29. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_from_padded_and_nested_example_compositeexplicitautograd_dispatch.h +24 -0
  30. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_get_jagged_dummy_ops.h +28 -0
  31. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_pack_padded_sequence_backward_native.h +21 -0
  32. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_coo_tensor_with_dims_ops.h +39 -0
  33. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_semi_structured_mm.h +30 -0
  34. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_standard_gamma_cuda_dispatch.h +23 -0
  35. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_check_tensor_ops.h +28 -0
  36. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_validate_sparse_bsr_tensor_args_compositeimplicitautograd_dispatch.h +23 -0
  37. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_values_copy_compositeexplicitautogradnonfunctional_dispatch.h +23 -0
  38. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/add_ops.h +83 -0
  39. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/arccosh_ops.h +50 -0
  40. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/as_strided_scatter_ops.h +39 -0
  41. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_update_stats_native.h +23 -0
  42. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/binary_cross_entropy_backward_cuda_dispatch.h +25 -0
  43. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/binomial_cuda_dispatch.h +23 -0
  44. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/clamp_native.h +27 -0
  45. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/conv3d_native.h +22 -0
  46. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/cudnn_convolution_transpose_compositeexplicitautograd_dispatch.h +26 -0
  47. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/dropout_compositeimplicitautograd_dispatch.h +24 -0
  48. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/elu_backward_meta.h +27 -0
  49. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/erf_native.h +29 -0
  50. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/expm1.h +44 -0
.venv/lib/python3.11/site-packages/torch/include/ATen/core/ATenGeneral.h ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/macros/Macros.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/core/ATen_fwd.h ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/core/QScheme.h>
3
+
4
+ // Forward declarations of core ATen types used in dispatch functions
5
+ namespace c10 {
6
+
7
+ template<typename T>
8
+ class List;
9
+ template<typename T>
10
+ class IListRef;
11
+ class Stream;
12
+ class Scalar;
13
+ class SymInt;
14
+ class SymIntList;
15
+ struct Storage;
16
+ struct TensorOptions;
17
+ template <typename T>
18
+ class ArrayRef;
19
+ template <typename T>
20
+ class OptionalArrayRef;
21
+
22
+ } // namespace c10
23
+
24
+ namespace at {
25
+
26
+ class Tensor;
27
+ class OptionalTensorRef;
28
+ struct Dimname;
29
+ struct Generator;
30
+ using TensorList = c10::ArrayRef<Tensor>;
31
+ using ITensorListRef = c10::IListRef<Tensor>;
32
+ using IOptTensorListRef = c10::IListRef<OptionalTensorRef>;
33
+ using DimnameList = c10::ArrayRef<Dimname>;
34
+ using IntArrayRef = c10::ArrayRef<int64_t>;
35
+ using OptionalIntArrayRef = c10::OptionalArrayRef<int64_t>;
36
+ using OptionalSymIntArrayRef = c10::OptionalArrayRef<c10::SymInt>;
37
+
38
+ using c10::Stream;
39
+ using c10::Storage;
40
+ using c10::QScheme;
41
+ using c10::Scalar;
42
+ using c10::SymInt;
43
+ using c10::SymIntList;
44
+ using c10::TensorOptions;
45
+
46
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/core/Dict_inl.h ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/ivalue.h>
4
+ #include <c10/util/hash.h>
5
+
6
+ namespace c10 {
7
+ namespace detail {
8
+ inline bool DictKeyEqualTo::operator()(const IValue& lhs, const IValue& rhs) const {
9
+ if (lhs.isTensor() && rhs.isTensor()) {
10
+ // for tensors, we compare only by identity (following how it's done in Python).
11
+ return lhs.is(rhs);
12
+ }
13
+ // Otherwise, we first compare by identity for efficiency, then by value (see:
14
+ // [container equality])
15
+ return _fastEqualsForContainer(lhs, rhs);
16
+ }
17
+ }
18
+
19
+ template<class T> decltype(auto) getTypePtr();
20
+ std::string toString(const Type& type);
21
+
22
+ namespace impl {
23
+
24
+ template<class Key, class Value>
25
+ Dict<Key, Value> toTypedDict(GenericDict dict) {
26
+ TORCH_INTERNAL_ASSERT(*getTypePtr<Key>() == *dict.impl_->elementTypes.keyType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()), ">. Key types mismatch.");
27
+ TORCH_INTERNAL_ASSERT(*getTypePtr<Value>() == *dict.impl_->elementTypes.valueType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()), ">. Value types mismatch.");
28
+
29
+ return Dict<Key, Value>(std::move(dict.impl_));
30
+ }
31
+
32
+ template<class Key, class Value>
33
+ GenericDict toGenericDict(Dict<Key, Value> dict) {
34
+ return GenericDict(std::move(dict.impl_));
35
+ }
36
+ }
37
+
38
+ namespace detail {
39
+
40
+ inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
41
+ if (ivalue.isInt()) {
42
+ return std::hash<int64_t>()(ivalue.toInt());
43
+ } else if (ivalue.isString()) {
44
+ return std::hash<c10::string_view>()(ivalue.toStringView());
45
+ } else if (ivalue.isDouble()) {
46
+ return std::hash<double>()(ivalue.toDouble());
47
+ } else if (ivalue.isComplexDouble()) {
48
+ return c10::hash<c10::complex<double>>()(ivalue.toComplexDouble());
49
+ } else if (ivalue.isBool()) {
50
+ return std::hash<bool>()(ivalue.toBool());
51
+ } else if (ivalue.isTensor()) {
52
+ return std::hash<TensorImpl*>()(ivalue.toTensor().unsafeGetTensorImpl());
53
+ } else if (ivalue.isDevice()) {
54
+ return std::hash<Device>()(ivalue.toDevice());
55
+ } else {
56
+ throw std::runtime_error(
57
+ "Can't hash IValues with tag '" + ivalue.tagKind() + "'");
58
+ }
59
+ }
60
+
61
+ inline intrusive_ptr<DictImpl> DictImpl::copy() const {
62
+ return make_intrusive<DictImpl>(dict, elementTypes);
63
+ }
64
+
65
+ }
66
+
67
+ template<class Key, class Value>
68
+ Dict<Key, Value>::Dict()
69
+ :Dict(make_intrusive<detail::DictImpl>(
70
+ detail::DictImpl::dict_map_type(),
71
+ detail::DictImpl::DictElementTypes{getTypePtr<Key>(), getTypePtr<Value>()})) {
72
+ static_assert(!std::is_same<Key, IValue>::value, "This constructor is not valid for Dict<IValue, _>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
73
+ static_assert(!std::is_same<Value, IValue>::value, "This constructor is not valid for Dict<_, IValue>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
74
+ }
75
+
76
+ template<class Key, class Value>
77
+ Dict<Key, Value>::Dict(TypePtr keyType, TypePtr valueType)
78
+ : Dict(make_intrusive<detail::DictImpl>(
79
+ detail::DictImpl::dict_map_type(),
80
+ detail::DictImpl::DictElementTypes {std::move(keyType), std::move(valueType)})) {
81
+ static_assert(std::is_same<Key, IValue>::value, "This constructor is only valid for c10::impl::GenericDict.");
82
+ static_assert(std::is_same<Value, IValue>::value, "This constructor is only valid for c10::impl::GenericDict.");
83
+ }
84
+
85
+ template<class Key, class Value>
86
+ Dict<Key, Value>::Dict(c10::intrusive_ptr<detail::DictImpl>&& impl): impl_(std::move(impl)) {}
87
+
88
+ template<class Key, class Value>
89
+ Dict<Key, Value> Dict<Key, Value>::copy() const {
90
+ return Dict<Key, Value>(impl_->copy());
91
+ }
92
+
93
+ template<class Key, class Value>
94
+ typename Dict<Key, Value>::iterator Dict<Key, Value>::begin() const {
95
+ return iterator{impl_->dict.begin()};
96
+ }
97
+
98
+ template<class Key, class Value>
99
+ typename Dict<Key, Value>::iterator Dict<Key, Value>::end() const {
100
+ return iterator{impl_->dict.end()};
101
+ }
102
+
103
+ template<class Key, class Value>
104
+ bool Dict<Key, Value>::empty() const {
105
+ return impl_->dict.empty();
106
+ }
107
+
108
+ template<class Key, class Value>
109
+ typename Dict<Key, Value>::size_type Dict<Key, Value>::size() const {
110
+ return impl_->dict.size();
111
+ }
112
+
113
+ template<class Key, class Value>
114
+ void Dict<Key, Value>::clear() const {
115
+ impl_->dict.clear();
116
+ }
117
+
118
+ template<class Key, class Value>
119
+ template<class Key_, class Value_>
120
+ std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert(Key_&& key, Value_&& value) const {
121
+ static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert");
122
+ static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert");
123
+ auto inserted = impl_->dict.emplace(
124
+ Key(std::forward<Key_>(key)),
125
+ Value(std::forward<Value_>(value)));
126
+ return {iterator{inserted.first}, inserted.second};
127
+ }
128
+
129
+ template<class Key, class Value>
130
+ template<class Key_, class Value_>
131
+ std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert_or_assign(Key_&& key, Value_&& value) const {
132
+ static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert_or_assign");
133
+ static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert_or_assign");
134
+ auto inserted = impl_->dict.insert_or_assign(
135
+ Key(std::forward<Key_>(key)),
136
+ Value(std::forward<Value_>(value)));
137
+ return {iterator{inserted.first}, inserted.second};
138
+ }
139
+
140
+ template<class Key, class Value>
141
+ void Dict<Key, Value>::erase(iterator iter) const {
142
+ impl_->dict.erase(iter.entryRef_.iterator_);
143
+ }
144
+
145
+ template<class Key, class Value>
146
+ C10_NODISCARD size_t Dict<Key, Value>::erase(const Key& key) const {
147
+ return impl_->dict.erase(key);
148
+ }
149
+
150
+ template<class Key, class Value>
151
+ Value Dict<Key, Value>::at(const Key& key) const {
152
+ return impl_->dict.at(key).template to<Value>();
153
+ }
154
+
155
+ template<class Key, class Value>
156
+ typename Dict<Key, Value>::iterator Dict<Key, Value>::find(const Key& key) const {
157
+ return iterator{impl_->dict.find(key)};
158
+ }
159
+
160
+ template<class Key, class Value>
161
+ bool Dict<Key, Value>::contains(const Key& key) const {
162
+ return end() != find(key);
163
+ }
164
+
165
+ template<class Key, class Value>
166
+ void Dict<Key, Value>::reserve(size_type count) const {
167
+ impl_->dict.reserve(count);
168
+ }
169
+
170
+ template<class Key, class Value>
171
+ TypePtr Dict<Key, Value>::keyType() const {
172
+ return impl_->elementTypes.keyType;
173
+ }
174
+
175
+ template<class Key, class Value>
176
+ TypePtr Dict<Key, Value>::valueType() const {
177
+ return impl_->elementTypes.valueType;
178
+ }
179
+ template <class Key, class Value>
180
+ void Dict<Key, Value>::unsafeSetKeyType(TypePtr t) {
181
+ impl_->elementTypes.keyType = std::move(t);
182
+ }
183
+
184
+ template <class Key, class Value>
185
+ void Dict<Key, Value>::unsafeSetValueType(TypePtr t) {
186
+ impl_->elementTypes.valueType = std::move(t);
187
+ }
188
+
189
+ template <class Key_, class Value_>
190
+ bool operator==(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
191
+ // Dicts with the same identity trivially compare equal.
192
+ if (lhs.impl_ == rhs.impl_) {
193
+ return true;
194
+ }
195
+
196
+ // Otherwise compare the values
197
+ return *lhs.impl_ == *rhs.impl_;
198
+ }
199
+
200
+ template <class Key_, class Value_>
201
+ bool operator!=(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
202
+ return !(lhs == rhs);
203
+ }
204
+
205
+ template <class Key, class Value>
206
+ bool Dict<Key, Value>::is(const Dict& rhs) const {
207
+ return this->impl_ == rhs.impl_;
208
+ }
209
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/core/List.h ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/ivalue_to.h>
4
+ #include <ATen/core/jit_type_base.h>
5
+ #include <c10/macros/Macros.h>
6
+ #include <c10/macros/Export.h>
7
+ #include <c10/util/TypeTraits.h>
8
+ #include <c10/util/TypeList.h>
9
+ #include <c10/util/intrusive_ptr.h>
10
+ #include <c10/util/ArrayRef.h>
11
+ #include <optional>
12
+ #include <vector>
13
+
14
+ namespace at {
15
+ class Tensor;
16
+ }
17
+ namespace c10 {
18
+ struct IValue;
19
+ template<class T> class List;
20
+ struct Type;
21
+
22
+ namespace detail {
23
+
24
+ struct ListImpl final : public c10::intrusive_ptr_target {
25
+ using list_type = std::vector<IValue>;
26
+
27
+ explicit TORCH_API ListImpl(list_type list_, TypePtr elementType_);
28
+
29
+ list_type list;
30
+
31
+ TypePtr elementType;
32
+
33
+ intrusive_ptr<ListImpl> copy() const {
34
+ return make_intrusive<ListImpl>(list, elementType);
35
+ }
36
+ friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs);
37
+ };
38
+ }
39
+
40
+ namespace impl {
41
+
42
+ template<class T, class Iterator> class ListIterator;
43
+
44
+ template<class T, class Iterator> class ListElementReference;
45
+
46
+ template<class T, class Iterator>
47
+ void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept;
48
+
49
+ template<class T, class Iterator>
50
+ bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs);
51
+
52
+ template<class T, class Iterator>
53
+ bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs);
54
+
55
+ template<class T>
56
+ struct ListElementConstReferenceTraits {
57
+ // In the general case, we use IValue::to().
58
+ using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return<T>::type;
59
+ };
60
+
61
+ // There is no to() overload for std::optional<std::string>.
62
+ template<>
63
+ struct ListElementConstReferenceTraits<std::optional<std::string>> {
64
+ using const_reference = std::optional<std::reference_wrapper<const std::string>>;
65
+ };
66
+
67
+ template<class T, class Iterator>
68
+ class ListElementReference final {
69
+ public:
70
+ operator std::conditional_t<
71
+ std::is_reference_v<typename c10::detail::
72
+ ivalue_to_const_ref_overload_return<T>::type>,
73
+ const T&,
74
+ T>() const;
75
+
76
+ ListElementReference& operator=(T&& new_value) &&;
77
+
78
+ ListElementReference& operator=(const T& new_value) &&;
79
+
80
+ // assigning another ref to this assigns the underlying value
81
+ ListElementReference& operator=(ListElementReference&& rhs) && noexcept;
82
+
83
+ const IValue& get() const& {
84
+ return *iterator_;
85
+ }
86
+
87
+ friend void swap<T, Iterator>(ListElementReference&& lhs, ListElementReference&& rhs) noexcept;
88
+
89
+ ListElementReference(const ListElementReference&) = delete;
90
+ ListElementReference& operator=(const ListElementReference&) = delete;
91
+
92
+ private:
93
+ ListElementReference(Iterator iter)
94
+ : iterator_(iter) {}
95
+
96
+ // allow moving, but only our friends (i.e. the List class) can move us
97
+ ListElementReference(ListElementReference&&) noexcept = default;
98
+ ListElementReference& operator=(ListElementReference&& rhs) & noexcept {
99
+ iterator_ = std::move(rhs.iterator_);
100
+ return *this;
101
+ }
102
+
103
+ friend class List<T>;
104
+ friend class ListIterator<T, Iterator>;
105
+
106
+ Iterator iterator_;
107
+ };
108
+
109
+ // this wraps vector::iterator to make sure user code can't rely
110
+ // on it being the type of the underlying vector.
111
+ template <class T, class Iterator>
112
+ class ListIterator final {
113
+ public:
114
+ // C++17 friendly std::iterator implementation
115
+ using iterator_category = std::random_access_iterator_tag;
116
+ using value_type = T;
117
+ using difference_type = std::ptrdiff_t;
118
+ using pointer = T*;
119
+ using reference = ListElementReference<T, Iterator>;
120
+
121
+ explicit ListIterator() = default;
122
+ ~ListIterator() = default;
123
+
124
+ ListIterator(const ListIterator&) = default;
125
+ ListIterator(ListIterator&&) noexcept = default;
126
+ ListIterator& operator=(const ListIterator&) = default;
127
+ ListIterator& operator=(ListIterator&&) noexcept = default;
128
+
129
+ ListIterator& operator++() {
130
+ ++iterator_;
131
+ return *this;
132
+ }
133
+
134
+ ListIterator operator++(int) {
135
+ ListIterator copy(*this);
136
+ ++*this;
137
+ return copy;
138
+ }
139
+
140
+ ListIterator& operator--() {
141
+ --iterator_;
142
+ return *this;
143
+ }
144
+
145
+ ListIterator operator--(int) {
146
+ ListIterator copy(*this);
147
+ --*this;
148
+ return copy;
149
+ }
150
+
151
+ ListIterator& operator+=(typename List<T>::size_type offset) {
152
+ iterator_ += offset;
153
+ return *this;
154
+ }
155
+
156
+ ListIterator& operator-=(typename List<T>::size_type offset) {
157
+ iterator_ -= offset;
158
+ return *this;
159
+ }
160
+
161
+ ListIterator operator+(typename List<T>::size_type offset) const {
162
+ return ListIterator{iterator_ + offset};
163
+ }
164
+
165
+ ListIterator operator-(typename List<T>::size_type offset) const {
166
+ return ListIterator{iterator_ - offset};
167
+ }
168
+
169
+ friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) {
170
+ return lhs.iterator_ - rhs.iterator_;
171
+ }
172
+
173
+ ListElementReference<T, Iterator> operator*() const {
174
+ return {iterator_};
175
+ }
176
+
177
+ ListElementReference<T, Iterator> operator[](typename List<T>::size_type offset) const {
178
+ return {iterator_ + offset};
179
+ }
180
+
181
+ private:
182
+ explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {}
183
+
184
+ Iterator iterator_;
185
+
186
+ friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) {
187
+ return lhs.iterator_ == rhs.iterator_;
188
+ }
189
+
190
+ friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) {
191
+ return !(lhs == rhs);
192
+ }
193
+
194
+ friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) {
195
+ return lhs.iterator_ < rhs.iterator_;
196
+ }
197
+
198
+ friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) {
199
+ return lhs.iterator_ <= rhs.iterator_;
200
+ }
201
+
202
+ friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) {
203
+ return lhs.iterator_ > rhs.iterator_;
204
+ }
205
+
206
+ friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) {
207
+ return lhs.iterator_ >= rhs.iterator_;
208
+ }
209
+
210
+ friend class ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
211
+ friend class List<T>;
212
+ };
213
+
214
+ template<class T> List<T> toTypedList(List<IValue> list);
215
+ template<class T> List<IValue> toList(List<T>&& list);
216
+ template<class T> List<IValue> toList(const List<T>& list);
217
+ const IValue* ptr_to_first_element(const List<IValue>& list);
218
+ }
219
+
220
+ /**
221
+ * An object of this class stores a list of values of type T.
222
+ *
223
+ * This is a pointer type. After a copy, both Lists
224
+ * will share the same storage:
225
+ *
226
+ * > List<int> a;
227
+ * > List<int> b = a;
228
+ * > b.push_back("three");
229
+ * > ASSERT("three" == a.get(0));
230
+ *
231
+ * We use this class in the PyTorch kernel API instead of
232
+ * std::vector<T>, because that allows us to do optimizations
233
+ * and switch out the underlying list implementation without
234
+ * breaking backwards compatibility for the kernel API.
235
+ */
236
+ template<class T>
237
+ class List final {
238
+ private:
239
+ // This is an intrusive_ptr because List is a pointer type.
240
+ // Invariant: This will never be a nullptr, there will always be a valid
241
+ // ListImpl.
242
+ c10::intrusive_ptr<c10::detail::ListImpl> impl_;
243
+
244
+ using internal_reference_type = impl::ListElementReference<T, typename c10::detail::ListImpl::list_type::iterator>;
245
+ using internal_const_reference_type = typename impl::ListElementConstReferenceTraits<T>::const_reference;
246
+
247
+ public:
248
+ using value_type = T;
249
+ using size_type = typename c10::detail::ListImpl::list_type::size_type;
250
+ using iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
251
+ using const_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
252
+ using reverse_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::reverse_iterator>;
253
+
254
+ /**
255
+ * Constructs an empty list.
256
+ */
257
+ explicit List();
258
+
259
+ /**
260
+ * Constructs a list with some initial values.
261
+ * Example:
262
+ * List<int> a({2, 3, 4});
263
+ */
264
+ List(std::initializer_list<T> initial_values);
265
+ explicit List(ArrayRef<T> initial_values);
266
+
267
+ /**
268
+ * Create a generic list with runtime type information.
269
+ * This only works for c10::impl::GenericList and is not part of the public API
270
+ * but only supposed to be used internally by PyTorch.
271
+ */
272
+ explicit List(TypePtr elementType);
273
+
274
+ List(const List&) = default;
275
+ List& operator=(const List&) = default;
276
+
277
+ /**
278
+ * Create a new List pointing to a deep copy of the same data.
279
+ * The List returned is a new list with separate storage.
280
+ * Changes in it are not reflected in the original list or vice versa.
281
+ */
282
+ List copy() const;
283
+
284
+ /**
285
+ * Returns the element at specified location pos, with bounds checking.
286
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
287
+ */
288
+ internal_const_reference_type get(size_type pos) const;
289
+
290
+ /**
291
+ * Moves out the element at the specified location pos and returns it, with bounds checking.
292
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
293
+ * The list contains an invalid element at position pos afterwards. Any operations
294
+ * on it before re-setting it are invalid.
295
+ */
296
+ value_type extract(size_type pos) const;
297
+
298
+ /**
299
+ * Returns a reference to the element at specified location pos, with bounds checking.
300
+ * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
301
+ *
302
+ * You cannot store the reference, but you can read it and assign new values to it:
303
+ *
304
+ * List<int64_t> list = ...;
305
+ * list[2] = 5;
306
+ * int64_t v = list[1];
307
+ */
308
+ internal_const_reference_type operator[](size_type pos) const;
309
+
310
+ internal_reference_type operator[](size_type pos);
311
+
312
+ /**
313
+ * Assigns a new value to the element at location pos.
314
+ */
315
+ void set(size_type pos, const value_type& value) const;
316
+
317
+ /**
318
+ * Assigns a new value to the element at location pos.
319
+ */
320
+ void set(size_type pos, value_type&& value) const;
321
+
322
+ /**
323
+ * Returns an iterator to the first element of the container.
324
+ * If the container is empty, the returned iterator will be equal to end().
325
+ */
326
+ iterator begin() const;
327
+
328
+ /**
329
+ * Returns an iterator to the element following the last element of the container.
330
+ * This element acts as a placeholder; attempting to access it results in undefined behavior.
331
+ */
332
+ iterator end() const;
333
+
334
+ /**
335
+ * Checks if the container has no elements.
336
+ */
337
+ bool empty() const;
338
+
339
+ /**
340
+ * Returns the number of elements in the container
341
+ */
342
+ size_type size() const;
343
+
344
+ /**
345
+ * Increase the capacity of the vector to a value that's greater or equal to new_cap.
346
+ */
347
+ void reserve(size_type new_cap) const;
348
+
349
+ /**
350
+ * Erases all elements from the container. After this call, size() returns zero.
351
+ * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated.
352
+ */
353
+ void clear() const;
354
+
355
+ /**
356
+ * Inserts value before pos.
357
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
358
+ */
359
+ iterator insert(iterator pos, const T& value) const;
360
+
361
+ /**
362
+ * Inserts value before pos.
363
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
364
+ */
365
+ iterator insert(iterator pos, T&& value) const;
366
+
367
+ /**
368
+ * Inserts a new element into the container directly before pos.
369
+ * The new element is constructed with the given arguments.
370
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
371
+ */
372
+ template<class... Args>
373
+ iterator emplace(iterator pos, Args&&... value) const;
374
+
375
+ /**
376
+ * Appends the given element value to the end of the container.
377
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
378
+ */
379
+ void push_back(const T& value) const;
380
+
381
+ /**
382
+ * Appends the given element value to the end of the container.
383
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
384
+ */
385
+ void push_back(T&& value) const;
386
+
387
+ /**
388
+ * Appends the given list to the end of the container. Uses at most one memory allocation.
389
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
390
+ */
391
+ void append(List<T> lst) const;
392
+
393
+ /**
394
+ * Appends the given element value to the end of the container.
395
+ * The new element is constructed with the given arguments.
396
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
397
+ */
398
+ template<class... Args>
399
+ void emplace_back(Args&&... args) const;
400
+
401
+ /**
402
+ * Removes the element at pos.
403
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
404
+ */
405
+ iterator erase(iterator pos) const;
406
+
407
+ /**
408
+ * Removes the elements in the range [first, last).
409
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
410
+ */
411
+ iterator erase(iterator first, iterator last) const;
412
+
413
+ /**
414
+ * Removes the last element of the container.
415
+ * Calling pop_back on an empty container is undefined.
416
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
417
+ */
418
+ void pop_back() const;
419
+
420
+ /**
421
+ * Resizes the container to contain count elements.
422
+ * If the current size is less than count, additional default-inserted elements are appended.
423
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
424
+ */
425
+ void resize(size_type count) const;
426
+
427
+ /**
428
+ * Resizes the container to contain count elements.
429
+ * If the current size is less than count, additional copies of value are appended.
430
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
431
+ */
432
+ void resize(size_type count, const T& value) const;
433
+
434
+ /**
435
+ * Value equality comparison. This function implements Python-like semantics for
436
+ * equality: two lists with the same identity (e.g. same pointer) trivially
437
+ * compare equal, otherwise each element is compared for equality.
438
+ */
439
+ template <class T_>
440
+ friend bool operator==(const List<T_>& lhs, const List<T_>& rhs);
441
+
442
+ template <class T_>
443
+ friend bool operator!=(const List<T_>& lhs, const List<T_>& rhs);
444
+
445
+ /**
446
+ * Identity comparison. Returns true if and only if `rhs` represents the same
447
+ * List object as `this`.
448
+ */
449
+ bool is(const List<T>& rhs) const;
450
+
451
+ std::vector<T> vec() const;
452
+
453
+ /**
454
+ * Returns the number of Lists currently pointing to this same list.
455
+ * If this is the only instance pointing to this list, returns 1.
456
+ */
457
+ // TODO Test use_count
458
+ size_t use_count() const;
459
+
460
+ TypePtr elementType() const;
461
+
462
+ // See [unsafe set type] for why this exists.
463
+ void unsafeSetElementType(TypePtr t);
464
+
465
+ private:
466
+ explicit List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements);
467
+ explicit List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements);
468
+ friend struct IValue;
469
+ template<class T_> friend List<T_> impl::toTypedList(List<IValue>);
470
+ template<class T_> friend List<IValue> impl::toList(List<T_>&&);
471
+ template<class T_> friend List<IValue> impl::toList(const List<T_>&);
472
+ friend const IValue* impl::ptr_to_first_element(const List<IValue>& list);
473
+ };
474
+
475
+ namespace impl {
476
+ // GenericList is how IValue stores lists. It is, however, not part of the
477
+ // public API. Kernels should use Lists with concrete types instead
478
+ // (maybe except for some internal prim ops).
479
+ using GenericList = List<IValue>;
480
+
481
+ }
482
+ }
483
+
484
+ namespace torch {
485
+ template<class T> using List = c10::List<T>;
486
+ }
487
+
488
+ #include <ATen/core/List_inl.h> // IWYU pragma: keep
.venv/lib/python3.11/site-packages/torch/include/ATen/core/List_inl.h ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/jit_type_base.h>
4
+ #include <ATen/core/ivalue.h>
5
+
6
+ namespace c10 {
7
+
8
+ template<class T> decltype(auto) getTypePtr();
9
+ std::string toString(const Type& type);
10
+
11
+ template<class T>
12
+ List<T>::List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements)
13
+ : impl_(std::move(elements)) {}
14
+
15
+ template<class T>
16
+ List<T>::List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements)
17
+ : impl_(elements) {}
18
+
19
+ template<class T>
20
+ List<T>::List()
21
+ : List(make_intrusive<c10::detail::ListImpl>(
22
+ typename c10::detail::ListImpl::list_type(),
23
+ getTypePtr<T>())) {
24
+ static_assert(!std::is_same<T, IValue>::value, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType) instead.");
25
+ }
26
+
27
+ template<class T>
28
+ List<T>::List(ArrayRef<T> values)
29
+ : List(make_intrusive<c10::detail::ListImpl>(
30
+ typename c10::detail::ListImpl::list_type(),
31
+ getTypePtr<T>())) {
32
+ static_assert(!std::is_same<T, IValue>::value, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
33
+ impl_->list.reserve(values.size());
34
+ for (const T& element : values) {
35
+ impl_->list.push_back(element);
36
+ }
37
+ }
38
+
39
+ template<class T>
40
+ List<T>::List(std::initializer_list<T> initial_values)
41
+ : List(ArrayRef<T>(initial_values)) {
42
+ static_assert(!std::is_same<T, IValue>::value, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
43
+ }
44
+
45
+ template<class T>
46
+ List<T>::List(TypePtr elementType)
47
+ : List(make_intrusive<c10::detail::ListImpl>(
48
+ typename c10::detail::ListImpl::list_type(),
49
+ std::move(elementType))) {
50
+ static_assert(std::is_same<T, IValue>::value || std::is_same<T, c10::intrusive_ptr<ivalue::Future>>::value,
51
+ "This constructor is only valid for c10::impl::GenericList or List<Future>.");
52
+ }
53
+
54
+ namespace impl {
55
+ template<class T>
56
+ List<T> toTypedList(impl::GenericList list) {
57
+ // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
58
+ // because upcasting would allow people to add types into the new list that would break the old list.
59
+ // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
60
+ // allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
61
+ // without having to copy it. This is also used to provide backwards compatibility with some old models
62
+ // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
63
+ // as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
64
+ // have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
65
+ TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
66
+ || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr<T>()))
67
+ , "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(*getTypePtr<T>()), ">. Types mismatch.");
68
+ return List<T>(std::move(list.impl_));
69
+ }
70
+
71
+ template<class T>
72
+ impl::GenericList toList(List<T>&& list) {
73
+ return GenericList(std::move(list.impl_));
74
+ }
75
+ template<class T>
76
+ impl::GenericList toList(const List<T>& list) {
77
+ return GenericList(list.impl_);
78
+ }
79
+ }
80
+
81
+ template<class T>
82
+ List<T> List<T>::copy() const {
83
+ return List<T>(impl_->copy());
84
+ }
85
+
86
+ namespace detail {
87
+ template<class T>
88
+ T list_element_to(T element) {
89
+ return element;
90
+ }
91
+ template<class T>
92
+ T list_element_to(const IValue& element) {
93
+ return element.template to<T>();
94
+ }
95
+ template<class T>
96
+ T list_element_to(IValue&& element) {
97
+ return std::move(element).template to<T>();
98
+ }
99
+ template<class T>
100
+ struct ListElementFrom {
101
+ static IValue from(const T& element) {
102
+ return element;
103
+ }
104
+ static IValue from(T&& element) {
105
+ return std::move(element);
106
+ }
107
+ };
108
+ template<>
109
+ struct ListElementFrom<IValue> {
110
+ static const IValue& from(const IValue& element) {
111
+ return element;
112
+ }
113
+ static IValue&& from(IValue&& element) {
114
+ return std::move(element);
115
+ }
116
+ };
117
+ }
118
+
119
+ namespace impl {
120
+
121
+ template <class T, class Iterator>
122
+ ListElementReference<T, Iterator>::operator std::conditional_t<
123
+ std::is_reference_v<typename c10::detail::ivalue_to_const_ref_overload_return<
124
+ T>::type>,
125
+ const T&,
126
+ T>() const {
127
+ return iterator_->template to<T>();
128
+ }
129
+
130
+ template<class T, class Iterator>
131
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(T&& new_value) && {
132
+ *iterator_ = c10::detail::ListElementFrom<T>::from(std::move(new_value));
133
+ return *this;
134
+ }
135
+
136
+ template<class T, class Iterator>
137
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(const T& new_value) && {
138
+ *iterator_ = c10::detail::ListElementFrom<T>::from(new_value);
139
+ return *this;
140
+ }
141
+
142
+ template<class T, class Iterator>
143
+ ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(ListElementReference<T, Iterator>&& rhs) && noexcept {
144
+ *iterator_ = *rhs.iterator_;
145
+ return *this;
146
+ }
147
+
148
+ template<class T, class Iterator>
149
+ void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept {
150
+ std::swap(*lhs.iterator_, *rhs.iterator_);
151
+ }
152
+
153
+ template<class T, class Iterator>
154
+ bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs) {
155
+ const T& lhs_tmp = lhs;
156
+ return lhs_tmp == rhs;
157
+ }
158
+
159
+ template<class T, class Iterator>
160
+ inline bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs) {
161
+ return rhs == lhs;
162
+ }
163
+
164
+ template<class T>
165
+ inline typename ListElementConstReferenceTraits<T>::const_reference
166
+ list_element_to_const_ref(const IValue& element) {
167
+ return element.template to<T>();
168
+ }
169
+
170
+ template<>
171
+ inline typename ListElementConstReferenceTraits<std::optional<std::string>>::const_reference
172
+ list_element_to_const_ref<std::optional<std::string>>(const IValue& element) {
173
+ return element.toOptionalStringRef();
174
+ }
175
+
176
+ } // namespace impl
177
+
178
+ template<class T>
179
+ void List<T>::set(size_type pos, const value_type& value) const {
180
+ impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(value);
181
+ }
182
+
183
+ template<class T>
184
+ void List<T>::set(size_type pos, value_type&& value) const {
185
+ impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(std::move(value));
186
+ }
187
+
188
+ template<class T>
189
+ typename List<T>::internal_const_reference_type List<T>::get(size_type pos) const {
190
+ return operator[](pos);
191
+ }
192
+
193
+ template<class T>
194
+ typename List<T>::internal_const_reference_type List<T>::operator[](size_type pos) const {
195
+ return c10::impl::list_element_to_const_ref<T>(impl_->list.at(pos));
196
+ }
197
+
198
+ template<class T>
199
+ typename List<T>::internal_reference_type List<T>::operator[](size_type pos) {
200
+ static_cast<void>(impl_->list.at(pos)); // Throw the exception if it is out of range.
201
+ return {impl_->list.begin() + static_cast<typename decltype(impl_->list)::difference_type>(pos)};
202
+ }
203
+
204
+ template<class T>
205
+ typename List<T>::value_type List<T>::extract(size_type pos) const {
206
+ auto& elem = impl_->list.at(pos);
207
+ auto result = c10::detail::list_element_to<T>(std::move(elem));
208
+ // Reset the list element to a T() instead of None to keep it correctly typed
209
+ elem = c10::detail::ListElementFrom<T>::from(T{});
210
+ return result;
211
+ }
212
+
213
+ template<class T>
214
+ typename List<T>::iterator List<T>::begin() const {
215
+ return iterator(impl_->list.begin());
216
+ }
217
+
218
+ template<class T>
219
+ typename List<T>::iterator List<T>::end() const {
220
+ return iterator(impl_->list.end());
221
+ }
222
+
223
+ template<class T>
224
+ bool List<T>::empty() const {
225
+ return impl_->list.empty();
226
+ }
227
+
228
+ template<class T>
229
+ typename List<T>::size_type List<T>::size() const {
230
+ return impl_->list.size();
231
+ }
232
+
233
+ template<class T>
234
+ void List<T>::reserve(size_type new_cap) const {
235
+ impl_->list.reserve(new_cap);
236
+ }
237
+
238
+ template<class T>
239
+ void List<T>::clear() const {
240
+ impl_->list.clear();
241
+ }
242
+
243
+ template<class T>
244
+ typename List<T>::iterator List<T>::insert(iterator pos, const T& value) const {
245
+ return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(value)) };
246
+ }
247
+
248
+ template<class T>
249
+ typename List<T>::iterator List<T>::insert(iterator pos, T&& value) const {
250
+ return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(std::move(value))) };
251
+ }
252
+
253
+ template<class T>
254
+ template<class... Args>
255
+ typename List<T>::iterator List<T>::emplace(iterator pos, Args&&... value) const {
256
+ // TODO Use list_element_from?
257
+ return iterator { impl_->list.emplace(pos.iterator_, std::forward<Args>(value)...) };
258
+ }
259
+
260
+ template<class T>
261
+ void List<T>::push_back(const T& value) const {
262
+ impl_->list.push_back(c10::detail::ListElementFrom<T>::from(value));
263
+ }
264
+
265
+ template<class T>
266
+ void List<T>::push_back(T&& value) const {
267
+ impl_->list.push_back(c10::detail::ListElementFrom<T>::from(std::move(value)));
268
+ }
269
+
270
+ template<class T>
271
+ void List<T>::append(List<T> b) const {
272
+ if (b.use_count() == 1) {
273
+ impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end()));
274
+ } else {
275
+ impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end());
276
+ }
277
+ }
278
+
279
+ template<class T>
280
+ template<class... Args>
281
+ void List<T>::emplace_back(Args&&... args) const {
282
+ // TODO Use list_element_from?
283
+ impl_->list.push_back(T(std::forward<Args>(args)...));
284
+ }
285
+
286
+ template<class T>
287
+ typename List<T>::iterator List<T>::erase(iterator pos) const {
288
+ return iterator { impl_->list.erase(pos.iterator_) };
289
+ }
290
+
291
+ template<class T>
292
+ typename List<T>::iterator List<T>::erase(iterator first, iterator last) const {
293
+ return iterator { impl_->list.erase(first.iterator_, last.iterator_) };
294
+ }
295
+
296
+ template<class T>
297
+ void List<T>::pop_back() const {
298
+ impl_->list.pop_back();
299
+ }
300
+
301
+ template<class T>
302
+ void List<T>::resize(size_type count) const {
303
+ impl_->list.resize(count, T{});
304
+ }
305
+
306
+ template<class T>
307
+ void List<T>::resize(size_type count, const T& value) const {
308
+ impl_->list.resize(count, value);
309
+ }
310
+
311
+ template<class T>
312
+ bool operator==(const List<T>& lhs, const List<T>& rhs) {
313
+ // Lists with the same identity trivially compare equal.
314
+ if (lhs.impl_ == rhs.impl_) {
315
+ return true;
316
+ }
317
+
318
+ // Otherwise, just compare values directly.
319
+ return *lhs.impl_ == *rhs.impl_;
320
+ }
321
+
322
+ template<class T>
323
+ bool operator!=(const List<T>& lhs, const List<T>& rhs) {
324
+ return !(lhs == rhs);
325
+ }
326
+
327
+ template<class T>
328
+ bool List<T>::is(const List<T>& rhs) const {
329
+ return this->impl_ == rhs.impl_;
330
+ }
331
+
332
+ template<class T>
333
+ std::vector<T> List<T>::vec() const {
334
+ std::vector<T> result(begin(), end());
335
+ return result;
336
+ }
337
+
338
+ template<class T>
339
+ size_t List<T>::use_count() const {
340
+ return impl_.use_count();
341
+ }
342
+
343
+ template <class T>
344
+ TypePtr List<T>::elementType() const {
345
+ return impl_->elementType;
346
+ }
347
+
348
+ template <class T>
349
+ void List<T>::unsafeSetElementType(TypePtr t) {
350
+ impl_->elementType = std::move(t);
351
+ }
352
+
353
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h ADDED
@@ -0,0 +1 @@
 
 
1
+ #include <c10/core/UndefinedTensorImpl.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/core/Variadic.h ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <utility>
4
+
5
+ #include <c10/util/ArrayRef.h>
6
+ #include <ATen/core/List.h>
7
+
8
+ namespace at {
9
+
10
+ // This class allows you to write variadic functions which
11
+ // call a (possibly overloaded) function on each argument,
12
+ // in order. This is most commonly used in autogenerated code,
13
+ // where it is convenient to have a function that can uniformly
14
+ // take arguments of different types. If your arguments
15
+ // are homogenous consider using a std::initializer_list instead.
16
+ //
17
+ // For examples of this in use, see torch/csrc/utils/variadic.h
18
+ template <typename F>
19
+ struct IterArgs {
20
+ template <typename... Args>
21
+ inline F& apply() {
22
+ return self();
23
+ }
24
+
25
+ // NB: Use perfect forwarding here, otherwise we'll make value
26
+ // copies of all arguments!
27
+ template <typename T, typename... Args>
28
+ inline F& apply(T&& arg, Args&&... args) {
29
+ self()(std::forward<T>(arg));
30
+ if (self().short_circuit()) {
31
+ return self();
32
+ } else {
33
+ return apply(std::forward<Args>(args)...);
34
+ }
35
+ }
36
+
37
+ // Here are some handy overloads which provide sensible
38
+ // defaults for container-like structures that one might
39
+ // be interested in recursing into. You can enable them
40
+ // by adding:
41
+ //
42
+ // using IterArgs<YourStructName>::operator()
43
+ //
44
+ // to your struct. These are not enabled by default because
45
+ // you may be able to process these structures more efficiently
46
+ // than handling them one-by-one.
47
+
48
+ template <typename T>
49
+ void operator()(c10::IListRef<T> args) {
50
+ for (const auto& arg : args) {
51
+ self()(arg);
52
+ if (self().short_circuit())
53
+ return;
54
+ }
55
+ }
56
+
57
+ template <typename T>
58
+ void operator()(at::ArrayRef<T> args) {
59
+ for (const auto& arg : args) {
60
+ self()(arg);
61
+ if (self().short_circuit())
62
+ return;
63
+ }
64
+ }
65
+
66
+ template <typename T>
67
+ void operator()(const torch::List<T>& args) {
68
+ for (const auto& arg : args) {
69
+ self()(arg);
70
+ if (self().short_circuit())
71
+ return;
72
+ }
73
+ }
74
+
75
+ // NB: we need to specify std::vector manually as C++ won't
76
+ // do an implicit conversion to make a template deduction go through.
77
+ template <typename T>
78
+ void operator()(const std::vector<T>& args) {
79
+ self()(at::ArrayRef<T>{args});
80
+ }
81
+
82
+ constexpr bool short_circuit() const {
83
+ return false;
84
+ }
85
+
86
+ private:
87
+ inline F& self() {
88
+ return *static_cast<F*>(this);
89
+ }
90
+ };
91
+
92
+ } // namespace torch
.venv/lib/python3.11/site-packages/torch/include/ATen/core/Vitals.h ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ostream>
3
+ #include <sstream>
4
+ #include <unordered_map>
5
+
6
+ #include <c10/core/impl/LocalDispatchKeySet.h>
7
+
8
+ namespace at::vitals {
9
+
10
+ TORCH_API bool torchVitalEnabled();
11
+
12
+ struct TORCH_API TorchVitalAttr {
13
+ // always initialized to empty
14
+ std::string value = "";
15
+ template <typename T>
16
+ TorchVitalAttr& operator<<(const T& t) {
17
+ if (torchVitalEnabled()) {
18
+ std::stringstream ss;
19
+ ss << t;
20
+ value += ss.str();
21
+ }
22
+ return *this;
23
+ }
24
+
25
+ template <typename T>
26
+ void write(const T& t, bool force) {
27
+ if (force || torchVitalEnabled()) {
28
+ std::stringstream ss;
29
+ ss << t;
30
+ value = ss.str();
31
+ }
32
+ }
33
+ };
34
+
35
+ struct TORCH_API TorchVital {
36
+ std::string name;
37
+ std::unordered_map<std::string, TorchVitalAttr> attrs;
38
+
39
+ explicit TorchVital(std::string n) : name(std::move(n)) {}
40
+ TorchVital(const TorchVital&) = default;
41
+ TorchVital(TorchVital&&) = default;
42
+ TorchVital() = delete;
43
+
44
+ TorchVitalAttr& create(const std::string& attr);
45
+ TorchVitalAttr& create(const std::string& attr, bool force);
46
+ friend std::ostream& operator<<(std::ostream& os, const TorchVital& dt);
47
+
48
+ ~TorchVital();
49
+ };
50
+
51
+ std::ostream& operator<<(std::ostream& os, TorchVital const& tv);
52
+
53
+ // A way to access vitals by string names instead of by global reference.
54
+ // This enables access to vitals from the PythonAPI.
55
+ class TORCH_API APIVitals {
56
+ public:
57
+ bool vitals_enabled;
58
+
59
+ // Set any vital sign that was added to the map.
60
+ bool setVital(
61
+ const std::string& vital_name,
62
+ const std::string& attr_name,
63
+ const std::string& value,
64
+ bool force = false);
65
+ std::string readVitals();
66
+
67
+ APIVitals();
68
+
69
+ // Ensure this stays a singleton
70
+ APIVitals(APIVitals const& other) = delete;
71
+ APIVitals(APIVitals&& other) = delete;
72
+ APIVitals& operator=(const APIVitals&) = delete;
73
+ APIVitals& operator=(APIVitals&&) = delete;
74
+
75
+ private:
76
+ std::unordered_map<std::string, TorchVital> name_map_;
77
+ };
78
+
79
+ extern TORCH_API APIVitals VitalsAPI;
80
+
81
+ } // namespace at::vitals
82
+
83
+ #define TORCH_VITAL_DECLARE(name) \
84
+ TORCH_API at::vitals::TorchVital TorchVital_##name;
85
+
86
+ #define TORCH_VITAL_DEFINE(name) \
87
+ TORCH_API at::vitals::TorchVital TorchVital_##name(#name);
88
+
89
+ #define TORCH_VITAL_BASE(name) TorchVital_##name
90
+
91
+ #define TORCH_VITAL(name, attr) TORCH_VITAL_BASE(name).create(#attr)
.venv/lib/python3.11/site-packages/torch/include/ATen/core/aten_interned_strings.h ADDED
@@ -0,0 +1,2264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from aten_interned_strings.h
4
+
5
+ #if defined(TORCH_ASSERT_NO_OPERATORS) || defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
6
+ #error This change adds a dependency on native_functions.yaml, \
7
+ meaning the file will need to be re-compiled every time an operator \
8
+ is changed or added. Consider if including <ATen/core/symbol.h> for \
9
+ the c10::Symbol class would be sufficient, or if your change would be \
10
+ better placed in another file.
11
+ #endif
12
+
13
+ // ATen symbols correspond exactly to operators defined in ATen. Every
14
+ // symbol here corresponds exactly to an ATen operation defined in
15
+ // native_functions.yaml; attributes are in one-to-one correspondence
16
+ // with their ATen name.
17
+
18
+ #define FORALL_ATEN_BASE_SYMBOLS(_) \
19
+ _(aten, __and__) \
20
+ _(aten, __iand__) \
21
+ _(aten, __ilshift__) \
22
+ _(aten, __ior__) \
23
+ _(aten, __irshift__) \
24
+ _(aten, __ixor__) \
25
+ _(aten, __lshift__) \
26
+ _(aten, __or__) \
27
+ _(aten, __rshift__) \
28
+ _(aten, __xor__) \
29
+ _(aten, _adaptive_avg_pool2d) \
30
+ _(aten, _adaptive_avg_pool2d_backward) \
31
+ _(aten, _adaptive_avg_pool3d) \
32
+ _(aten, _adaptive_avg_pool3d_backward) \
33
+ _(aten, _add_batch_dim) \
34
+ _(aten, _add_relu) \
35
+ _(aten, _add_relu_) \
36
+ _(aten, _addmm_activation) \
37
+ _(aten, _aminmax) \
38
+ _(aten, _amp_foreach_non_finite_check_and_unscale) \
39
+ _(aten, _amp_foreach_non_finite_check_and_unscale_) \
40
+ _(aten, _amp_update_scale) \
41
+ _(aten, _amp_update_scale_) \
42
+ _(aten, _assert_async) \
43
+ _(aten, _assert_scalar) \
44
+ _(aten, _assert_tensor_metadata) \
45
+ _(aten, _autocast_to_full_precision) \
46
+ _(aten, _autocast_to_reduced_precision) \
47
+ _(aten, _backward) \
48
+ _(aten, _batch_norm_impl_index) \
49
+ _(aten, _batch_norm_impl_index_backward) \
50
+ _(aten, _batch_norm_no_update) \
51
+ _(aten, _batch_norm_with_update) \
52
+ _(aten, _batch_norm_with_update_functional) \
53
+ _(aten, _cast_Byte) \
54
+ _(aten, _cast_Char) \
55
+ _(aten, _cast_Double) \
56
+ _(aten, _cast_Float) \
57
+ _(aten, _cast_Half) \
58
+ _(aten, _cast_Int) \
59
+ _(aten, _cast_Long) \
60
+ _(aten, _cast_Short) \
61
+ _(aten, _cdist_backward) \
62
+ _(aten, _cdist_forward) \
63
+ _(aten, _cholesky_solve_helper) \
64
+ _(aten, _choose_qparams_per_tensor) \
65
+ _(aten, _chunk_cat) \
66
+ _(aten, _coalesce) \
67
+ _(aten, _coalesced) \
68
+ _(aten, _coalesced_) \
69
+ _(aten, _compute_linear_combination) \
70
+ _(aten, _conj) \
71
+ _(aten, _conj_copy) \
72
+ _(aten, _conj_physical) \
73
+ _(aten, _conv_depthwise2d) \
74
+ _(aten, _convert_indices_from_coo_to_csr) \
75
+ _(aten, _convert_indices_from_csr_to_coo) \
76
+ _(aten, _convert_weight_to_int4pack) \
77
+ _(aten, _convolution) \
78
+ _(aten, _convolution_double_backward) \
79
+ _(aten, _convolution_mode) \
80
+ _(aten, _copy_from) \
81
+ _(aten, _copy_from_and_resize) \
82
+ _(aten, _cslt_compress) \
83
+ _(aten, _cslt_sparse_mm) \
84
+ _(aten, _cslt_sparse_mm_search) \
85
+ _(aten, _ctc_loss) \
86
+ _(aten, _ctc_loss_backward) \
87
+ _(aten, _cudnn_ctc_loss) \
88
+ _(aten, _cudnn_init_dropout_state) \
89
+ _(aten, _cudnn_rnn) \
90
+ _(aten, _cudnn_rnn_backward) \
91
+ _(aten, _cudnn_rnn_flatten_weight) \
92
+ _(aten, _cufft_clear_plan_cache) \
93
+ _(aten, _cufft_get_plan_cache_max_size) \
94
+ _(aten, _cufft_get_plan_cache_size) \
95
+ _(aten, _cufft_set_plan_cache_max_size) \
96
+ _(aten, _cummax_helper) \
97
+ _(aten, _cummin_helper) \
98
+ _(aten, _debug_has_internal_overlap) \
99
+ _(aten, _dimI) \
100
+ _(aten, _dimV) \
101
+ _(aten, _dim_arange) \
102
+ _(aten, _dirichlet_grad) \
103
+ _(aten, _efficient_attention_backward) \
104
+ _(aten, _efficient_attention_forward) \
105
+ _(aten, _efficientzerotensor) \
106
+ _(aten, _embedding_bag) \
107
+ _(aten, _embedding_bag_backward) \
108
+ _(aten, _embedding_bag_dense_backward) \
109
+ _(aten, _embedding_bag_forward_only) \
110
+ _(aten, _embedding_bag_per_sample_weights_backward) \
111
+ _(aten, _embedding_bag_sparse_backward) \
112
+ _(aten, _empty_affine_quantized) \
113
+ _(aten, _empty_per_channel_affine_quantized) \
114
+ _(aten, _euclidean_dist) \
115
+ _(aten, _fake_quantize_learnable_per_channel_affine) \
116
+ _(aten, _fake_quantize_learnable_per_channel_affine_backward) \
117
+ _(aten, _fake_quantize_learnable_per_tensor_affine) \
118
+ _(aten, _fake_quantize_learnable_per_tensor_affine_backward) \
119
+ _(aten, _fake_quantize_per_tensor_affine_cachemask_tensor_qparams) \
120
+ _(aten, _fft_c2c) \
121
+ _(aten, _fft_c2r) \
122
+ _(aten, _fft_r2c) \
123
+ _(aten, _fill_mem_eff_dropout_mask) \
124
+ _(aten, _fill_mem_eff_dropout_mask_) \
125
+ _(aten, _flash_attention_backward) \
126
+ _(aten, _flash_attention_forward) \
127
+ _(aten, _foobar) \
128
+ _(aten, _foreach_abs) \
129
+ _(aten, _foreach_abs_) \
130
+ _(aten, _foreach_acos) \
131
+ _(aten, _foreach_acos_) \
132
+ _(aten, _foreach_add) \
133
+ _(aten, _foreach_add_) \
134
+ _(aten, _foreach_addcdiv) \
135
+ _(aten, _foreach_addcdiv_) \
136
+ _(aten, _foreach_addcmul) \
137
+ _(aten, _foreach_addcmul_) \
138
+ _(aten, _foreach_asin) \
139
+ _(aten, _foreach_asin_) \
140
+ _(aten, _foreach_atan) \
141
+ _(aten, _foreach_atan_) \
142
+ _(aten, _foreach_ceil) \
143
+ _(aten, _foreach_ceil_) \
144
+ _(aten, _foreach_clamp_max) \
145
+ _(aten, _foreach_clamp_max_) \
146
+ _(aten, _foreach_clamp_min) \
147
+ _(aten, _foreach_clamp_min_) \
148
+ _(aten, _foreach_copy) \
149
+ _(aten, _foreach_copy_) \
150
+ _(aten, _foreach_cos) \
151
+ _(aten, _foreach_cos_) \
152
+ _(aten, _foreach_cosh) \
153
+ _(aten, _foreach_cosh_) \
154
+ _(aten, _foreach_div) \
155
+ _(aten, _foreach_div_) \
156
+ _(aten, _foreach_erf) \
157
+ _(aten, _foreach_erf_) \
158
+ _(aten, _foreach_erfc) \
159
+ _(aten, _foreach_erfc_) \
160
+ _(aten, _foreach_exp) \
161
+ _(aten, _foreach_exp_) \
162
+ _(aten, _foreach_expm1) \
163
+ _(aten, _foreach_expm1_) \
164
+ _(aten, _foreach_floor) \
165
+ _(aten, _foreach_floor_) \
166
+ _(aten, _foreach_frac) \
167
+ _(aten, _foreach_frac_) \
168
+ _(aten, _foreach_lerp) \
169
+ _(aten, _foreach_lerp_) \
170
+ _(aten, _foreach_lgamma) \
171
+ _(aten, _foreach_lgamma_) \
172
+ _(aten, _foreach_log) \
173
+ _(aten, _foreach_log10) \
174
+ _(aten, _foreach_log10_) \
175
+ _(aten, _foreach_log1p) \
176
+ _(aten, _foreach_log1p_) \
177
+ _(aten, _foreach_log2) \
178
+ _(aten, _foreach_log2_) \
179
+ _(aten, _foreach_log_) \
180
+ _(aten, _foreach_max) \
181
+ _(aten, _foreach_maximum) \
182
+ _(aten, _foreach_maximum_) \
183
+ _(aten, _foreach_minimum) \
184
+ _(aten, _foreach_minimum_) \
185
+ _(aten, _foreach_mul) \
186
+ _(aten, _foreach_mul_) \
187
+ _(aten, _foreach_neg) \
188
+ _(aten, _foreach_neg_) \
189
+ _(aten, _foreach_norm) \
190
+ _(aten, _foreach_pow) \
191
+ _(aten, _foreach_pow_) \
192
+ _(aten, _foreach_reciprocal) \
193
+ _(aten, _foreach_reciprocal_) \
194
+ _(aten, _foreach_round) \
195
+ _(aten, _foreach_round_) \
196
+ _(aten, _foreach_sigmoid) \
197
+ _(aten, _foreach_sigmoid_) \
198
+ _(aten, _foreach_sign) \
199
+ _(aten, _foreach_sign_) \
200
+ _(aten, _foreach_sin) \
201
+ _(aten, _foreach_sin_) \
202
+ _(aten, _foreach_sinh) \
203
+ _(aten, _foreach_sinh_) \
204
+ _(aten, _foreach_sqrt) \
205
+ _(aten, _foreach_sqrt_) \
206
+ _(aten, _foreach_sub) \
207
+ _(aten, _foreach_sub_) \
208
+ _(aten, _foreach_tan) \
209
+ _(aten, _foreach_tan_) \
210
+ _(aten, _foreach_tanh) \
211
+ _(aten, _foreach_tanh_) \
212
+ _(aten, _foreach_trunc) \
213
+ _(aten, _foreach_trunc_) \
214
+ _(aten, _foreach_zero) \
215
+ _(aten, _foreach_zero_) \
216
+ _(aten, _functional_assert_async) \
217
+ _(aten, _functional_assert_scalar) \
218
+ _(aten, _functional_sym_constrain_range) \
219
+ _(aten, _functional_sym_constrain_range_for_size) \
220
+ _(aten, _fused_adagrad) \
221
+ _(aten, _fused_adagrad_) \
222
+ _(aten, _fused_adam) \
223
+ _(aten, _fused_adam_) \
224
+ _(aten, _fused_adamw) \
225
+ _(aten, _fused_adamw_) \
226
+ _(aten, _fused_dropout) \
227
+ _(aten, _fused_moving_avg_obs_fq_helper) \
228
+ _(aten, _fused_moving_avg_obs_fq_helper_functional) \
229
+ _(aten, _fused_sdp_choice) \
230
+ _(aten, _fused_sgd) \
231
+ _(aten, _fused_sgd_) \
232
+ _(aten, _fw_primal) \
233
+ _(aten, _fw_primal_copy) \
234
+ _(aten, _gather_sparse_backward) \
235
+ _(aten, _grid_sampler_2d_cpu_fallback) \
236
+ _(aten, _grid_sampler_2d_cpu_fallback_backward) \
237
+ _(aten, _has_compatible_shallow_copy_type) \
238
+ _(aten, _has_same_storage_numel) \
239
+ _(aten, _histogramdd_bin_edges) \
240
+ _(aten, _histogramdd_from_bin_cts) \
241
+ _(aten, _histogramdd_from_bin_tensors) \
242
+ _(aten, _index_put_impl) \
243
+ _(aten, _index_put_impl_) \
244
+ _(aten, _indices) \
245
+ _(aten, _indices_copy) \
246
+ _(aten, _int_mm) \
247
+ _(aten, _is_all_true) \
248
+ _(aten, _is_any_true) \
249
+ _(aten, _is_zerotensor) \
250
+ _(aten, _jagged_to_padded_dense_forward) \
251
+ _(aten, _lazy_clone) \
252
+ _(aten, _linalg_check_errors) \
253
+ _(aten, _linalg_det) \
254
+ _(aten, _linalg_eigh) \
255
+ _(aten, _linalg_eigvals) \
256
+ _(aten, _linalg_slogdet) \
257
+ _(aten, _linalg_solve_ex) \
258
+ _(aten, _linalg_svd) \
259
+ _(aten, _local_scalar_dense) \
260
+ _(aten, _log_softmax) \
261
+ _(aten, _log_softmax_backward_data) \
262
+ _(aten, _logcumsumexp) \
263
+ _(aten, _lstm_mps) \
264
+ _(aten, _lu_with_info) \
265
+ _(aten, _make_dep_token) \
266
+ _(aten, _make_dual) \
267
+ _(aten, _make_dual_copy) \
268
+ _(aten, _make_per_channel_quantized_tensor) \
269
+ _(aten, _make_per_tensor_quantized_tensor) \
270
+ _(aten, _masked_scale) \
271
+ _(aten, _masked_softmax) \
272
+ _(aten, _masked_softmax_backward) \
273
+ _(aten, _mixed_dtypes_linear) \
274
+ _(aten, _mkldnn_reshape) \
275
+ _(aten, _mkldnn_transpose) \
276
+ _(aten, _mkldnn_transpose_) \
277
+ _(aten, _mps_convolution) \
278
+ _(aten, _mps_convolution_transpose) \
279
+ _(aten, _native_batch_norm_legit) \
280
+ _(aten, _native_batch_norm_legit_functional) \
281
+ _(aten, _native_batch_norm_legit_no_training) \
282
+ _(aten, _native_multi_head_attention) \
283
+ _(aten, _neg_view) \
284
+ _(aten, _neg_view_copy) \
285
+ _(aten, _nested_compute_contiguous_strides_offsets) \
286
+ _(aten, _nested_from_padded) \
287
+ _(aten, _nested_from_padded_and_nested_example) \
288
+ _(aten, _nested_get_jagged_dummy) \
289
+ _(aten, _nested_get_lengths) \
290
+ _(aten, _nested_get_max_seqlen) \
291
+ _(aten, _nested_get_min_seqlen) \
292
+ _(aten, _nested_get_offsets) \
293
+ _(aten, _nested_get_ragged_idx) \
294
+ _(aten, _nested_get_values) \
295
+ _(aten, _nested_get_values_copy) \
296
+ _(aten, _nested_select_backward) \
297
+ _(aten, _nested_sum_backward) \
298
+ _(aten, _nested_tensor_from_mask) \
299
+ _(aten, _nested_tensor_from_mask_left_aligned) \
300
+ _(aten, _nested_tensor_from_tensor_list) \
301
+ _(aten, _nested_tensor_size) \
302
+ _(aten, _nested_tensor_softmax_with_shape) \
303
+ _(aten, _nested_tensor_storage_offsets) \
304
+ _(aten, _nested_tensor_strides) \
305
+ _(aten, _nested_view_from_buffer) \
306
+ _(aten, _nested_view_from_buffer_copy) \
307
+ _(aten, _nested_view_from_jagged) \
308
+ _(aten, _nested_view_from_jagged_copy) \
309
+ _(aten, _new_zeros_with_same_feature_meta) \
310
+ _(aten, _nnpack_available) \
311
+ _(aten, _nnpack_spatial_convolution) \
312
+ _(aten, _nnz) \
313
+ _(aten, _pack_padded_sequence) \
314
+ _(aten, _pack_padded_sequence_backward) \
315
+ _(aten, _pad_circular) \
316
+ _(aten, _pad_enum) \
317
+ _(aten, _pad_packed_sequence) \
318
+ _(aten, _padded_dense_to_jagged_forward) \
319
+ _(aten, _pdist_backward) \
320
+ _(aten, _pdist_forward) \
321
+ _(aten, _pin_memory) \
322
+ _(aten, _prelu_kernel) \
323
+ _(aten, _prelu_kernel_backward) \
324
+ _(aten, _print) \
325
+ _(aten, _propagate_xla_data) \
326
+ _(aten, _remove_batch_dim) \
327
+ _(aten, _reshape_alias) \
328
+ _(aten, _reshape_alias_copy) \
329
+ _(aten, _reshape_copy) \
330
+ _(aten, _reshape_from_tensor) \
331
+ _(aten, _resize_output) \
332
+ _(aten, _resize_output_) \
333
+ _(aten, _rowwise_prune) \
334
+ _(aten, _safe_softmax) \
335
+ _(aten, _sample_dirichlet) \
336
+ _(aten, _saturate_weight_to_fp16) \
337
+ _(aten, _scaled_dot_product_attention_math) \
338
+ _(aten, _scaled_dot_product_attention_math_for_mps) \
339
+ _(aten, _scaled_dot_product_cudnn_attention) \
340
+ _(aten, _scaled_dot_product_cudnn_attention_backward) \
341
+ _(aten, _scaled_dot_product_efficient_attention) \
342
+ _(aten, _scaled_dot_product_efficient_attention_backward) \
343
+ _(aten, _scaled_dot_product_flash_attention) \
344
+ _(aten, _scaled_dot_product_flash_attention_backward) \
345
+ _(aten, _scaled_dot_product_flash_attention_for_cpu) \
346
+ _(aten, _scaled_dot_product_flash_attention_for_cpu_backward) \
347
+ _(aten, _scaled_dot_product_fused_attention_overrideable) \
348
+ _(aten, _scaled_dot_product_fused_attention_overrideable_backward) \
349
+ _(aten, _scaled_mm) \
350
+ _(aten, _segment_reduce_backward) \
351
+ _(aten, _shape_as_tensor) \
352
+ _(aten, _slow_conv2d_backward) \
353
+ _(aten, _slow_conv2d_forward) \
354
+ _(aten, _sobol_engine_draw) \
355
+ _(aten, _sobol_engine_ff) \
356
+ _(aten, _sobol_engine_ff_) \
357
+ _(aten, _sobol_engine_initialize_state) \
358
+ _(aten, _sobol_engine_initialize_state_) \
359
+ _(aten, _sobol_engine_scramble) \
360
+ _(aten, _sobol_engine_scramble_) \
361
+ _(aten, _softmax) \
362
+ _(aten, _softmax_backward_data) \
363
+ _(aten, _sparse_addmm) \
364
+ _(aten, _sparse_broadcast_to) \
365
+ _(aten, _sparse_broadcast_to_copy) \
366
+ _(aten, _sparse_bsc_tensor_unsafe) \
367
+ _(aten, _sparse_bsr_tensor_unsafe) \
368
+ _(aten, _sparse_compressed_tensor_unsafe) \
369
+ _(aten, _sparse_compressed_tensor_with_dims) \
370
+ _(aten, _sparse_coo_tensor_unsafe) \
371
+ _(aten, _sparse_coo_tensor_with_dims) \
372
+ _(aten, _sparse_coo_tensor_with_dims_and_tensors) \
373
+ _(aten, _sparse_csc_tensor_unsafe) \
374
+ _(aten, _sparse_csr_prod) \
375
+ _(aten, _sparse_csr_sum) \
376
+ _(aten, _sparse_csr_tensor_unsafe) \
377
+ _(aten, _sparse_log_softmax) \
378
+ _(aten, _sparse_log_softmax_backward_data) \
379
+ _(aten, _sparse_mask_projection) \
380
+ _(aten, _sparse_mm) \
381
+ _(aten, _sparse_mm_reduce_impl) \
382
+ _(aten, _sparse_mm_reduce_impl_backward) \
383
+ _(aten, _sparse_semi_structured_addmm) \
384
+ _(aten, _sparse_semi_structured_apply) \
385
+ _(aten, _sparse_semi_structured_apply_dense) \
386
+ _(aten, _sparse_semi_structured_linear) \
387
+ _(aten, _sparse_semi_structured_mm) \
388
+ _(aten, _sparse_semi_structured_tile) \
389
+ _(aten, _sparse_softmax) \
390
+ _(aten, _sparse_softmax_backward_data) \
391
+ _(aten, _sparse_sparse_matmul) \
392
+ _(aten, _sparse_sum) \
393
+ _(aten, _sparse_sum_backward) \
394
+ _(aten, _spdiags) \
395
+ _(aten, _spsolve) \
396
+ _(aten, _stack) \
397
+ _(aten, _standard_gamma) \
398
+ _(aten, _standard_gamma_grad) \
399
+ _(aten, _test_ambiguous_defaults) \
400
+ _(aten, _test_autograd_multiple_dispatch) \
401
+ _(aten, _test_autograd_multiple_dispatch_view) \
402
+ _(aten, _test_autograd_multiple_dispatch_view_copy) \
403
+ _(aten, _test_check_tensor) \
404
+ _(aten, _test_functorch_fallback) \
405
+ _(aten, _test_optional_filled_intlist) \
406
+ _(aten, _test_optional_floatlist) \
407
+ _(aten, _test_optional_intlist) \
408
+ _(aten, _test_parallel_materialize) \
409
+ _(aten, _test_serialization_subcmul) \
410
+ _(aten, _test_string_default) \
411
+ _(aten, _test_warn_in_autograd) \
412
+ _(aten, _thnn_differentiable_gru_cell_backward) \
413
+ _(aten, _thnn_differentiable_lstm_cell_backward) \
414
+ _(aten, _thnn_fused_gru_cell) \
415
+ _(aten, _thnn_fused_gru_cell_backward) \
416
+ _(aten, _thnn_fused_lstm_cell) \
417
+ _(aten, _thnn_fused_lstm_cell_backward) \
418
+ _(aten, _thnn_fused_lstm_cell_backward_impl) \
419
+ _(aten, _to_copy) \
420
+ _(aten, _to_cpu) \
421
+ _(aten, _to_dense) \
422
+ _(aten, _to_sparse) \
423
+ _(aten, _to_sparse_bsc) \
424
+ _(aten, _to_sparse_bsr) \
425
+ _(aten, _to_sparse_csc) \
426
+ _(aten, _to_sparse_csr) \
427
+ _(aten, _to_sparse_semi_structured) \
428
+ _(aten, _transform_bias_rescale_qkv) \
429
+ _(aten, _transformer_encoder_layer_fwd) \
430
+ _(aten, _trilinear) \
431
+ _(aten, _triton_multi_head_attention) \
432
+ _(aten, _triton_scaled_dot_attention) \
433
+ _(aten, _unique) \
434
+ _(aten, _unique2) \
435
+ _(aten, _unpack_dual) \
436
+ _(aten, _unsafe_index) \
437
+ _(aten, _unsafe_index_put) \
438
+ _(aten, _unsafe_masked_index) \
439
+ _(aten, _unsafe_masked_index_put_accumulate) \
440
+ _(aten, _unsafe_view) \
441
+ _(aten, _upsample_bicubic2d_aa) \
442
+ _(aten, _upsample_bicubic2d_aa_backward) \
443
+ _(aten, _upsample_bilinear2d_aa) \
444
+ _(aten, _upsample_bilinear2d_aa_backward) \
445
+ _(aten, _upsample_nearest_exact1d) \
446
+ _(aten, _upsample_nearest_exact1d_backward) \
447
+ _(aten, _upsample_nearest_exact2d) \
448
+ _(aten, _upsample_nearest_exact2d_backward) \
449
+ _(aten, _upsample_nearest_exact3d) \
450
+ _(aten, _upsample_nearest_exact3d_backward) \
451
+ _(aten, _use_cudnn_ctc_loss) \
452
+ _(aten, _use_cudnn_rnn_flatten_weight) \
453
+ _(aten, _validate_compressed_sparse_indices) \
454
+ _(aten, _validate_sparse_bsc_tensor_args) \
455
+ _(aten, _validate_sparse_bsr_tensor_args) \
456
+ _(aten, _validate_sparse_compressed_tensor_args) \
457
+ _(aten, _validate_sparse_coo_tensor_args) \
458
+ _(aten, _validate_sparse_csc_tensor_args) \
459
+ _(aten, _validate_sparse_csr_tensor_args) \
460
+ _(aten, _values) \
461
+ _(aten, _values_copy) \
462
+ _(aten, _version) \
463
+ _(aten, _weight_int4pack_mm) \
464
+ _(aten, _weight_int8pack_mm) \
465
+ _(aten, _weight_norm) \
466
+ _(aten, _weight_norm_differentiable_backward) \
467
+ _(aten, _weight_norm_interface) \
468
+ _(aten, _weight_norm_interface_backward) \
469
+ _(aten, _wrapped_linear_prepack) \
470
+ _(aten, _wrapped_quantized_linear_prepacked) \
471
+ _(aten, abs) \
472
+ _(aten, abs_) \
473
+ _(aten, absolute) \
474
+ _(aten, absolute_) \
475
+ _(aten, acos) \
476
+ _(aten, acos_) \
477
+ _(aten, acosh) \
478
+ _(aten, acosh_) \
479
+ _(aten, adaptive_avg_pool1d) \
480
+ _(aten, adaptive_avg_pool2d) \
481
+ _(aten, adaptive_avg_pool3d) \
482
+ _(aten, adaptive_avg_pool3d_backward) \
483
+ _(aten, adaptive_max_pool1d) \
484
+ _(aten, adaptive_max_pool2d) \
485
+ _(aten, adaptive_max_pool2d_backward) \
486
+ _(aten, adaptive_max_pool3d) \
487
+ _(aten, adaptive_max_pool3d_backward) \
488
+ _(aten, add) \
489
+ _(aten, add_) \
490
+ _(aten, addbmm) \
491
+ _(aten, addbmm_) \
492
+ _(aten, addcdiv) \
493
+ _(aten, addcdiv_) \
494
+ _(aten, addcmul) \
495
+ _(aten, addcmul_) \
496
+ _(aten, addmm) \
497
+ _(aten, addmm_) \
498
+ _(aten, addmv) \
499
+ _(aten, addmv_) \
500
+ _(aten, addr) \
501
+ _(aten, addr_) \
502
+ _(aten, adjoint) \
503
+ _(aten, affine_grid_generator) \
504
+ _(aten, affine_grid_generator_backward) \
505
+ _(aten, alias) \
506
+ _(aten, alias_copy) \
507
+ _(aten, align_as) \
508
+ _(aten, align_tensors) \
509
+ _(aten, align_to) \
510
+ _(aten, all) \
511
+ _(aten, allclose) \
512
+ _(aten, alpha_dropout) \
513
+ _(aten, alpha_dropout_) \
514
+ _(aten, amax) \
515
+ _(aten, amin) \
516
+ _(aten, aminmax) \
517
+ _(aten, angle) \
518
+ _(aten, any) \
519
+ _(aten, arange) \
520
+ _(aten, arccos) \
521
+ _(aten, arccos_) \
522
+ _(aten, arccosh) \
523
+ _(aten, arccosh_) \
524
+ _(aten, arcsin) \
525
+ _(aten, arcsin_) \
526
+ _(aten, arcsinh) \
527
+ _(aten, arcsinh_) \
528
+ _(aten, arctan) \
529
+ _(aten, arctan2) \
530
+ _(aten, arctan2_) \
531
+ _(aten, arctan_) \
532
+ _(aten, arctanh) \
533
+ _(aten, arctanh_) \
534
+ _(aten, argmax) \
535
+ _(aten, argmin) \
536
+ _(aten, argsort) \
537
+ _(aten, argwhere) \
538
+ _(aten, as_strided) \
539
+ _(aten, as_strided_) \
540
+ _(aten, as_strided_copy) \
541
+ _(aten, as_strided_scatter) \
542
+ _(aten, asin) \
543
+ _(aten, asin_) \
544
+ _(aten, asinh) \
545
+ _(aten, asinh_) \
546
+ _(aten, atan) \
547
+ _(aten, atan2) \
548
+ _(aten, atan2_) \
549
+ _(aten, atan_) \
550
+ _(aten, atanh) \
551
+ _(aten, atanh_) \
552
+ _(aten, atleast_1d) \
553
+ _(aten, atleast_2d) \
554
+ _(aten, atleast_3d) \
555
+ _(aten, avg_pool1d) \
556
+ _(aten, avg_pool2d) \
557
+ _(aten, avg_pool2d_backward) \
558
+ _(aten, avg_pool3d) \
559
+ _(aten, avg_pool3d_backward) \
560
+ _(aten, baddbmm) \
561
+ _(aten, baddbmm_) \
562
+ _(aten, bartlett_window) \
563
+ _(aten, batch_norm) \
564
+ _(aten, batch_norm_backward) \
565
+ _(aten, batch_norm_backward_elemt) \
566
+ _(aten, batch_norm_backward_reduce) \
567
+ _(aten, batch_norm_elemt) \
568
+ _(aten, batch_norm_gather_stats) \
569
+ _(aten, batch_norm_gather_stats_with_counts) \
570
+ _(aten, batch_norm_stats) \
571
+ _(aten, batch_norm_update_stats) \
572
+ _(aten, bernoulli) \
573
+ _(aten, bernoulli_) \
574
+ _(aten, bilinear) \
575
+ _(aten, binary_cross_entropy) \
576
+ _(aten, binary_cross_entropy_backward) \
577
+ _(aten, binary_cross_entropy_with_logits) \
578
+ _(aten, bincount) \
579
+ _(aten, binomial) \
580
+ _(aten, bitwise_and) \
581
+ _(aten, bitwise_and_) \
582
+ _(aten, bitwise_left_shift) \
583
+ _(aten, bitwise_left_shift_) \
584
+ _(aten, bitwise_not) \
585
+ _(aten, bitwise_not_) \
586
+ _(aten, bitwise_or) \
587
+ _(aten, bitwise_or_) \
588
+ _(aten, bitwise_right_shift) \
589
+ _(aten, bitwise_right_shift_) \
590
+ _(aten, bitwise_xor) \
591
+ _(aten, bitwise_xor_) \
592
+ _(aten, blackman_window) \
593
+ _(aten, block_diag) \
594
+ _(aten, bmm) \
595
+ _(aten, broadcast_tensors) \
596
+ _(aten, broadcast_to) \
597
+ _(aten, bucketize) \
598
+ _(aten, can_cast) \
599
+ _(aten, cartesian_prod) \
600
+ _(aten, cat) \
601
+ _(aten, cauchy) \
602
+ _(aten, cauchy_) \
603
+ _(aten, ccol_indices) \
604
+ _(aten, ccol_indices_copy) \
605
+ _(aten, cdist) \
606
+ _(aten, ceil) \
607
+ _(aten, ceil_) \
608
+ _(aten, celu) \
609
+ _(aten, celu_) \
610
+ _(aten, chain_matmul) \
611
+ _(aten, chalf) \
612
+ _(aten, channel_shuffle) \
613
+ _(aten, cholesky) \
614
+ _(aten, cholesky_inverse) \
615
+ _(aten, cholesky_solve) \
616
+ _(aten, choose_qparams_optimized) \
617
+ _(aten, chunk) \
618
+ _(aten, clamp) \
619
+ _(aten, clamp_) \
620
+ _(aten, clamp_max) \
621
+ _(aten, clamp_max_) \
622
+ _(aten, clamp_min) \
623
+ _(aten, clamp_min_) \
624
+ _(aten, clip) \
625
+ _(aten, clip_) \
626
+ _(aten, clone) \
627
+ _(aten, coalesce) \
628
+ _(aten, col2im) \
629
+ _(aten, col_indices) \
630
+ _(aten, col_indices_copy) \
631
+ _(aten, column_stack) \
632
+ _(aten, combinations) \
633
+ _(aten, complex) \
634
+ _(aten, concat) \
635
+ _(aten, concatenate) \
636
+ _(aten, conj) \
637
+ _(aten, conj_physical) \
638
+ _(aten, conj_physical_) \
639
+ _(aten, constant_pad_nd) \
640
+ _(aten, contiguous) \
641
+ _(aten, conv1d) \
642
+ _(aten, conv2d) \
643
+ _(aten, conv3d) \
644
+ _(aten, conv_depthwise3d) \
645
+ _(aten, conv_tbc) \
646
+ _(aten, conv_tbc_backward) \
647
+ _(aten, conv_transpose1d) \
648
+ _(aten, conv_transpose2d) \
649
+ _(aten, conv_transpose3d) \
650
+ _(aten, convolution) \
651
+ _(aten, convolution_backward) \
652
+ _(aten, convolution_backward_overrideable) \
653
+ _(aten, convolution_overrideable) \
654
+ _(aten, copy) \
655
+ _(aten, copy_) \
656
+ _(aten, copy_sparse_to_sparse) \
657
+ _(aten, copy_sparse_to_sparse_) \
658
+ _(aten, copysign) \
659
+ _(aten, copysign_) \
660
+ _(aten, corrcoef) \
661
+ _(aten, cos) \
662
+ _(aten, cos_) \
663
+ _(aten, cosh) \
664
+ _(aten, cosh_) \
665
+ _(aten, cosine_embedding_loss) \
666
+ _(aten, cosine_similarity) \
667
+ _(aten, count_nonzero) \
668
+ _(aten, cov) \
669
+ _(aten, cross) \
670
+ _(aten, cross_entropy_loss) \
671
+ _(aten, crow_indices) \
672
+ _(aten, crow_indices_copy) \
673
+ _(aten, ctc_loss) \
674
+ _(aten, cudnn_affine_grid_generator) \
675
+ _(aten, cudnn_affine_grid_generator_backward) \
676
+ _(aten, cudnn_batch_norm) \
677
+ _(aten, cudnn_batch_norm_backward) \
678
+ _(aten, cudnn_convolution) \
679
+ _(aten, cudnn_convolution_add_relu) \
680
+ _(aten, cudnn_convolution_relu) \
681
+ _(aten, cudnn_convolution_transpose) \
682
+ _(aten, cudnn_grid_sampler) \
683
+ _(aten, cudnn_grid_sampler_backward) \
684
+ _(aten, cudnn_is_acceptable) \
685
+ _(aten, cummax) \
686
+ _(aten, cummaxmin_backward) \
687
+ _(aten, cummin) \
688
+ _(aten, cumprod) \
689
+ _(aten, cumprod_) \
690
+ _(aten, cumprod_backward) \
691
+ _(aten, cumsum) \
692
+ _(aten, cumsum_) \
693
+ _(aten, cumulative_trapezoid) \
694
+ _(aten, data) \
695
+ _(aten, deg2rad) \
696
+ _(aten, deg2rad_) \
697
+ _(aten, dense_dim) \
698
+ _(aten, dequantize) \
699
+ _(aten, det) \
700
+ _(aten, detach) \
701
+ _(aten, detach_) \
702
+ _(aten, detach_copy) \
703
+ _(aten, diag) \
704
+ _(aten, diag_embed) \
705
+ _(aten, diagflat) \
706
+ _(aten, diagonal) \
707
+ _(aten, diagonal_backward) \
708
+ _(aten, diagonal_copy) \
709
+ _(aten, diagonal_scatter) \
710
+ _(aten, diff) \
711
+ _(aten, digamma) \
712
+ _(aten, digamma_) \
713
+ _(aten, dist) \
714
+ _(aten, div) \
715
+ _(aten, div_) \
716
+ _(aten, divide) \
717
+ _(aten, divide_) \
718
+ _(aten, dot) \
719
+ _(aten, dropout) \
720
+ _(aten, dropout_) \
721
+ _(aten, dsplit) \
722
+ _(aten, dstack) \
723
+ _(aten, einsum) \
724
+ _(aten, elu) \
725
+ _(aten, elu_) \
726
+ _(aten, elu_backward) \
727
+ _(aten, embedding) \
728
+ _(aten, embedding_backward) \
729
+ _(aten, embedding_bag) \
730
+ _(aten, embedding_dense_backward) \
731
+ _(aten, embedding_renorm) \
732
+ _(aten, embedding_renorm_) \
733
+ _(aten, embedding_sparse_backward) \
734
+ _(aten, empty) \
735
+ _(aten, empty_like) \
736
+ _(aten, empty_permuted) \
737
+ _(aten, empty_quantized) \
738
+ _(aten, empty_strided) \
739
+ _(aten, eq) \
740
+ _(aten, eq_) \
741
+ _(aten, equal) \
742
+ _(aten, erf) \
743
+ _(aten, erf_) \
744
+ _(aten, erfc) \
745
+ _(aten, erfc_) \
746
+ _(aten, erfinv) \
747
+ _(aten, erfinv_) \
748
+ _(aten, exp) \
749
+ _(aten, exp2) \
750
+ _(aten, exp2_) \
751
+ _(aten, exp_) \
752
+ _(aten, expand) \
753
+ _(aten, expand_as) \
754
+ _(aten, expand_copy) \
755
+ _(aten, expm1) \
756
+ _(aten, expm1_) \
757
+ _(aten, exponential) \
758
+ _(aten, exponential_) \
759
+ _(aten, eye) \
760
+ _(aten, fake_quantize_per_channel_affine) \
761
+ _(aten, fake_quantize_per_channel_affine_cachemask) \
762
+ _(aten, fake_quantize_per_channel_affine_cachemask_backward) \
763
+ _(aten, fake_quantize_per_tensor_affine) \
764
+ _(aten, fake_quantize_per_tensor_affine_cachemask) \
765
+ _(aten, fake_quantize_per_tensor_affine_cachemask_backward) \
766
+ _(aten, fbgemm_linear_fp16_weight) \
767
+ _(aten, fbgemm_linear_fp16_weight_fp32_activation) \
768
+ _(aten, fbgemm_linear_int8_weight) \
769
+ _(aten, fbgemm_linear_int8_weight_fp32_activation) \
770
+ _(aten, fbgemm_linear_quantize_weight) \
771
+ _(aten, fbgemm_pack_gemm_matrix_fp16) \
772
+ _(aten, fbgemm_pack_quantized_matrix) \
773
+ _(aten, feature_alpha_dropout) \
774
+ _(aten, feature_alpha_dropout_) \
775
+ _(aten, feature_dropout) \
776
+ _(aten, feature_dropout_) \
777
+ _(aten, fft_fft) \
778
+ _(aten, fft_fft2) \
779
+ _(aten, fft_fftfreq) \
780
+ _(aten, fft_fftn) \
781
+ _(aten, fft_fftshift) \
782
+ _(aten, fft_hfft) \
783
+ _(aten, fft_hfft2) \
784
+ _(aten, fft_hfftn) \
785
+ _(aten, fft_ifft) \
786
+ _(aten, fft_ifft2) \
787
+ _(aten, fft_ifftn) \
788
+ _(aten, fft_ifftshift) \
789
+ _(aten, fft_ihfft) \
790
+ _(aten, fft_ihfft2) \
791
+ _(aten, fft_ihfftn) \
792
+ _(aten, fft_irfft) \
793
+ _(aten, fft_irfft2) \
794
+ _(aten, fft_irfftn) \
795
+ _(aten, fft_rfft) \
796
+ _(aten, fft_rfft2) \
797
+ _(aten, fft_rfftfreq) \
798
+ _(aten, fft_rfftn) \
799
+ _(aten, fill) \
800
+ _(aten, fill_) \
801
+ _(aten, fill_diagonal) \
802
+ _(aten, fill_diagonal_) \
803
+ _(aten, fix) \
804
+ _(aten, fix_) \
805
+ _(aten, flatten) \
806
+ _(aten, flatten_dense_tensors) \
807
+ _(aten, flip) \
808
+ _(aten, fliplr) \
809
+ _(aten, flipud) \
810
+ _(aten, float_power) \
811
+ _(aten, float_power_) \
812
+ _(aten, floor) \
813
+ _(aten, floor_) \
814
+ _(aten, floor_divide) \
815
+ _(aten, floor_divide_) \
816
+ _(aten, fmax) \
817
+ _(aten, fmin) \
818
+ _(aten, fmod) \
819
+ _(aten, fmod_) \
820
+ _(aten, frac) \
821
+ _(aten, frac_) \
822
+ _(aten, fractional_max_pool2d) \
823
+ _(aten, fractional_max_pool2d_backward) \
824
+ _(aten, fractional_max_pool3d) \
825
+ _(aten, fractional_max_pool3d_backward) \
826
+ _(aten, frexp) \
827
+ _(aten, frobenius_norm) \
828
+ _(aten, from_file) \
829
+ _(aten, full) \
830
+ _(aten, full_like) \
831
+ _(aten, fused_moving_avg_obs_fake_quant) \
832
+ _(aten, gather) \
833
+ _(aten, gather_backward) \
834
+ _(aten, gcd) \
835
+ _(aten, gcd_) \
836
+ _(aten, ge) \
837
+ _(aten, ge_) \
838
+ _(aten, gelu) \
839
+ _(aten, gelu_) \
840
+ _(aten, gelu_backward) \
841
+ _(aten, geometric) \
842
+ _(aten, geometric_) \
843
+ _(aten, geqrf) \
844
+ _(aten, ger) \
845
+ _(aten, glu) \
846
+ _(aten, glu_backward) \
847
+ _(aten, glu_backward_jvp) \
848
+ _(aten, glu_jvp) \
849
+ _(aten, gradient) \
850
+ _(aten, greater) \
851
+ _(aten, greater_) \
852
+ _(aten, greater_equal) \
853
+ _(aten, greater_equal_) \
854
+ _(aten, grid_sampler) \
855
+ _(aten, grid_sampler_2d) \
856
+ _(aten, grid_sampler_2d_backward) \
857
+ _(aten, grid_sampler_3d) \
858
+ _(aten, grid_sampler_3d_backward) \
859
+ _(aten, group_norm) \
860
+ _(aten, gru) \
861
+ _(aten, gru_cell) \
862
+ _(aten, gt) \
863
+ _(aten, gt_) \
864
+ _(aten, hamming_window) \
865
+ _(aten, hann_window) \
866
+ _(aten, hardshrink) \
867
+ _(aten, hardshrink_backward) \
868
+ _(aten, hardsigmoid) \
869
+ _(aten, hardsigmoid_) \
870
+ _(aten, hardsigmoid_backward) \
871
+ _(aten, hardswish) \
872
+ _(aten, hardswish_) \
873
+ _(aten, hardswish_backward) \
874
+ _(aten, hardtanh) \
875
+ _(aten, hardtanh_) \
876
+ _(aten, hardtanh_backward) \
877
+ _(aten, heaviside) \
878
+ _(aten, heaviside_) \
879
+ _(aten, hinge_embedding_loss) \
880
+ _(aten, histc) \
881
+ _(aten, histogram) \
882
+ _(aten, histogramdd) \
883
+ _(aten, hsplit) \
884
+ _(aten, hspmm) \
885
+ _(aten, hstack) \
886
+ _(aten, huber_loss) \
887
+ _(aten, huber_loss_backward) \
888
+ _(aten, hypot) \
889
+ _(aten, hypot_) \
890
+ _(aten, i0) \
891
+ _(aten, i0_) \
892
+ _(aten, igamma) \
893
+ _(aten, igamma_) \
894
+ _(aten, igammac) \
895
+ _(aten, igammac_) \
896
+ _(aten, im2col) \
897
+ _(aten, imag) \
898
+ _(aten, index) \
899
+ _(aten, index_add) \
900
+ _(aten, index_add_) \
901
+ _(aten, index_copy) \
902
+ _(aten, index_copy_) \
903
+ _(aten, index_fill) \
904
+ _(aten, index_fill_) \
905
+ _(aten, index_put) \
906
+ _(aten, index_put_) \
907
+ _(aten, index_reduce) \
908
+ _(aten, index_reduce_) \
909
+ _(aten, index_select) \
910
+ _(aten, index_select_backward) \
911
+ _(aten, indices) \
912
+ _(aten, indices_copy) \
913
+ _(aten, infinitely_differentiable_gelu_backward) \
914
+ _(aten, inner) \
915
+ _(aten, instance_norm) \
916
+ _(aten, int_repr) \
917
+ _(aten, inverse) \
918
+ _(aten, is_coalesced) \
919
+ _(aten, is_complex) \
920
+ _(aten, is_conj) \
921
+ _(aten, is_distributed) \
922
+ _(aten, is_floating_point) \
923
+ _(aten, is_inference) \
924
+ _(aten, is_leaf) \
925
+ _(aten, is_neg) \
926
+ _(aten, is_nonzero) \
927
+ _(aten, is_pinned) \
928
+ _(aten, is_same_size) \
929
+ _(aten, is_set_to) \
930
+ _(aten, is_signed) \
931
+ _(aten, is_vulkan_available) \
932
+ _(aten, isclose) \
933
+ _(aten, isfinite) \
934
+ _(aten, isin) \
935
+ _(aten, isinf) \
936
+ _(aten, isnan) \
937
+ _(aten, isneginf) \
938
+ _(aten, isposinf) \
939
+ _(aten, isreal) \
940
+ _(aten, istft) \
941
+ _(aten, item) \
942
+ _(aten, kaiser_window) \
943
+ _(aten, kl_div) \
944
+ _(aten, kron) \
945
+ _(aten, kthvalue) \
946
+ _(aten, l1_loss) \
947
+ _(aten, layer_norm) \
948
+ _(aten, lcm) \
949
+ _(aten, lcm_) \
950
+ _(aten, ldexp) \
951
+ _(aten, ldexp_) \
952
+ _(aten, le) \
953
+ _(aten, le_) \
954
+ _(aten, leaky_relu) \
955
+ _(aten, leaky_relu_) \
956
+ _(aten, leaky_relu_backward) \
957
+ _(aten, lerp) \
958
+ _(aten, lerp_) \
959
+ _(aten, less) \
960
+ _(aten, less_) \
961
+ _(aten, less_equal) \
962
+ _(aten, less_equal_) \
963
+ _(aten, lgamma) \
964
+ _(aten, lgamma_) \
965
+ _(aten, lift) \
966
+ _(aten, lift_fresh) \
967
+ _(aten, lift_fresh_copy) \
968
+ _(aten, linalg_cholesky) \
969
+ _(aten, linalg_cholesky_ex) \
970
+ _(aten, linalg_cond) \
971
+ _(aten, linalg_cross) \
972
+ _(aten, linalg_det) \
973
+ _(aten, linalg_diagonal) \
974
+ _(aten, linalg_eig) \
975
+ _(aten, linalg_eigh) \
976
+ _(aten, linalg_eigvals) \
977
+ _(aten, linalg_eigvalsh) \
978
+ _(aten, linalg_householder_product) \
979
+ _(aten, linalg_inv) \
980
+ _(aten, linalg_inv_ex) \
981
+ _(aten, linalg_ldl_factor) \
982
+ _(aten, linalg_ldl_factor_ex) \
983
+ _(aten, linalg_ldl_solve) \
984
+ _(aten, linalg_lstsq) \
985
+ _(aten, linalg_lu) \
986
+ _(aten, linalg_lu_factor) \
987
+ _(aten, linalg_lu_factor_ex) \
988
+ _(aten, linalg_lu_solve) \
989
+ _(aten, linalg_matmul) \
990
+ _(aten, linalg_matrix_exp) \
991
+ _(aten, linalg_matrix_norm) \
992
+ _(aten, linalg_matrix_power) \
993
+ _(aten, linalg_matrix_rank) \
994
+ _(aten, linalg_multi_dot) \
995
+ _(aten, linalg_norm) \
996
+ _(aten, linalg_pinv) \
997
+ _(aten, linalg_qr) \
998
+ _(aten, linalg_slogdet) \
999
+ _(aten, linalg_solve) \
1000
+ _(aten, linalg_solve_ex) \
1001
+ _(aten, linalg_solve_triangular) \
1002
+ _(aten, linalg_svd) \
1003
+ _(aten, linalg_svdvals) \
1004
+ _(aten, linalg_tensorinv) \
1005
+ _(aten, linalg_tensorsolve) \
1006
+ _(aten, linalg_vander) \
1007
+ _(aten, linalg_vecdot) \
1008
+ _(aten, linalg_vector_norm) \
1009
+ _(aten, linear) \
1010
+ _(aten, linear_backward) \
1011
+ _(aten, linspace) \
1012
+ _(aten, log) \
1013
+ _(aten, log10) \
1014
+ _(aten, log10_) \
1015
+ _(aten, log1p) \
1016
+ _(aten, log1p_) \
1017
+ _(aten, log2) \
1018
+ _(aten, log2_) \
1019
+ _(aten, log_) \
1020
+ _(aten, log_normal) \
1021
+ _(aten, log_normal_) \
1022
+ _(aten, log_sigmoid) \
1023
+ _(aten, log_sigmoid_backward) \
1024
+ _(aten, log_sigmoid_forward) \
1025
+ _(aten, log_softmax) \
1026
+ _(aten, logaddexp) \
1027
+ _(aten, logaddexp2) \
1028
+ _(aten, logcumsumexp) \
1029
+ _(aten, logdet) \
1030
+ _(aten, logical_and) \
1031
+ _(aten, logical_and_) \
1032
+ _(aten, logical_not) \
1033
+ _(aten, logical_not_) \
1034
+ _(aten, logical_or) \
1035
+ _(aten, logical_or_) \
1036
+ _(aten, logical_xor) \
1037
+ _(aten, logical_xor_) \
1038
+ _(aten, logit) \
1039
+ _(aten, logit_) \
1040
+ _(aten, logit_backward) \
1041
+ _(aten, logspace) \
1042
+ _(aten, logsumexp) \
1043
+ _(aten, lshift) \
1044
+ _(aten, lstm) \
1045
+ _(aten, lstm_cell) \
1046
+ _(aten, lstm_mps_backward) \
1047
+ _(aten, lt) \
1048
+ _(aten, lt_) \
1049
+ _(aten, lu_solve) \
1050
+ _(aten, lu_unpack) \
1051
+ _(aten, mH) \
1052
+ _(aten, mT) \
1053
+ _(aten, margin_ranking_loss) \
1054
+ _(aten, masked_fill) \
1055
+ _(aten, masked_fill_) \
1056
+ _(aten, masked_scatter) \
1057
+ _(aten, masked_scatter_) \
1058
+ _(aten, masked_scatter_backward) \
1059
+ _(aten, masked_select) \
1060
+ _(aten, masked_select_backward) \
1061
+ _(aten, matmul) \
1062
+ _(aten, matmul_backward) \
1063
+ _(aten, matrix_H) \
1064
+ _(aten, matrix_exp) \
1065
+ _(aten, matrix_exp_backward) \
1066
+ _(aten, matrix_power) \
1067
+ _(aten, max) \
1068
+ _(aten, max_pool1d) \
1069
+ _(aten, max_pool1d_with_indices) \
1070
+ _(aten, max_pool2d) \
1071
+ _(aten, max_pool2d_backward) \
1072
+ _(aten, max_pool2d_with_indices) \
1073
+ _(aten, max_pool2d_with_indices_backward) \
1074
+ _(aten, max_pool3d) \
1075
+ _(aten, max_pool3d_with_indices) \
1076
+ _(aten, max_pool3d_with_indices_backward) \
1077
+ _(aten, max_unpool2d) \
1078
+ _(aten, max_unpool3d) \
1079
+ _(aten, maximum) \
1080
+ _(aten, mean) \
1081
+ _(aten, median) \
1082
+ _(aten, meshgrid) \
1083
+ _(aten, min) \
1084
+ _(aten, minimum) \
1085
+ _(aten, miopen_batch_norm) \
1086
+ _(aten, miopen_batch_norm_backward) \
1087
+ _(aten, miopen_convolution) \
1088
+ _(aten, miopen_convolution_add_relu) \
1089
+ _(aten, miopen_convolution_relu) \
1090
+ _(aten, miopen_convolution_transpose) \
1091
+ _(aten, miopen_depthwise_convolution) \
1092
+ _(aten, miopen_rnn) \
1093
+ _(aten, miopen_rnn_backward) \
1094
+ _(aten, mish) \
1095
+ _(aten, mish_) \
1096
+ _(aten, mish_backward) \
1097
+ _(aten, mkldnn_adaptive_avg_pool2d) \
1098
+ _(aten, mkldnn_adaptive_avg_pool2d_backward) \
1099
+ _(aten, mkldnn_convolution) \
1100
+ _(aten, mkldnn_linear) \
1101
+ _(aten, mkldnn_linear_backward) \
1102
+ _(aten, mkldnn_linear_backward_input) \
1103
+ _(aten, mkldnn_linear_backward_weights) \
1104
+ _(aten, mkldnn_max_pool2d) \
1105
+ _(aten, mkldnn_max_pool2d_backward) \
1106
+ _(aten, mkldnn_max_pool3d) \
1107
+ _(aten, mkldnn_max_pool3d_backward) \
1108
+ _(aten, mkldnn_reorder_conv2d_weight) \
1109
+ _(aten, mkldnn_reorder_conv3d_weight) \
1110
+ _(aten, mkldnn_rnn_layer) \
1111
+ _(aten, mkldnn_rnn_layer_backward) \
1112
+ _(aten, mm) \
1113
+ _(aten, mode) \
1114
+ _(aten, moveaxis) \
1115
+ _(aten, movedim) \
1116
+ _(aten, mps_convolution_backward) \
1117
+ _(aten, mps_convolution_transpose_backward) \
1118
+ _(aten, mse_loss) \
1119
+ _(aten, mse_loss_backward) \
1120
+ _(aten, msort) \
1121
+ _(aten, mul) \
1122
+ _(aten, mul_) \
1123
+ _(aten, multi_margin_loss) \
1124
+ _(aten, multi_margin_loss_backward) \
1125
+ _(aten, multilabel_margin_loss) \
1126
+ _(aten, multilabel_margin_loss_backward) \
1127
+ _(aten, multilabel_margin_loss_forward) \
1128
+ _(aten, multinomial) \
1129
+ _(aten, multiply) \
1130
+ _(aten, multiply_) \
1131
+ _(aten, mv) \
1132
+ _(aten, mvlgamma) \
1133
+ _(aten, mvlgamma_) \
1134
+ _(aten, nan_to_num) \
1135
+ _(aten, nan_to_num_) \
1136
+ _(aten, nanmean) \
1137
+ _(aten, nanmedian) \
1138
+ _(aten, nanquantile) \
1139
+ _(aten, nansum) \
1140
+ _(aten, narrow) \
1141
+ _(aten, narrow_copy) \
1142
+ _(aten, native_batch_norm) \
1143
+ _(aten, native_batch_norm_backward) \
1144
+ _(aten, native_channel_shuffle) \
1145
+ _(aten, native_dropout) \
1146
+ _(aten, native_dropout_backward) \
1147
+ _(aten, native_group_norm) \
1148
+ _(aten, native_group_norm_backward) \
1149
+ _(aten, native_layer_norm) \
1150
+ _(aten, native_layer_norm_backward) \
1151
+ _(aten, native_norm) \
1152
+ _(aten, ne) \
1153
+ _(aten, ne_) \
1154
+ _(aten, neg) \
1155
+ _(aten, neg_) \
1156
+ _(aten, negative) \
1157
+ _(aten, negative_) \
1158
+ _(aten, nested_to_padded_tensor) \
1159
+ _(aten, new_empty) \
1160
+ _(aten, new_empty_strided) \
1161
+ _(aten, new_full) \
1162
+ _(aten, new_ones) \
1163
+ _(aten, new_zeros) \
1164
+ _(aten, nextafter) \
1165
+ _(aten, nextafter_) \
1166
+ _(aten, nll_loss) \
1167
+ _(aten, nll_loss2d) \
1168
+ _(aten, nll_loss2d_backward) \
1169
+ _(aten, nll_loss2d_forward) \
1170
+ _(aten, nll_loss_backward) \
1171
+ _(aten, nll_loss_forward) \
1172
+ _(aten, nll_loss_nd) \
1173
+ _(aten, nonzero) \
1174
+ _(aten, nonzero_numpy) \
1175
+ _(aten, nonzero_static) \
1176
+ _(aten, norm) \
1177
+ _(aten, norm_except_dim) \
1178
+ _(aten, normal) \
1179
+ _(aten, normal_) \
1180
+ _(aten, normal_functional) \
1181
+ _(aten, not_equal) \
1182
+ _(aten, not_equal_) \
1183
+ _(aten, nuclear_norm) \
1184
+ _(aten, numpy_T) \
1185
+ _(aten, one_hot) \
1186
+ _(aten, ones) \
1187
+ _(aten, ones_like) \
1188
+ _(aten, orgqr) \
1189
+ _(aten, ormqr) \
1190
+ _(aten, outer) \
1191
+ _(aten, output_nr) \
1192
+ _(aten, pad) \
1193
+ _(aten, pad_sequence) \
1194
+ _(aten, pairwise_distance) \
1195
+ _(aten, pdist) \
1196
+ _(aten, permute) \
1197
+ _(aten, permute_copy) \
1198
+ _(aten, pin_memory) \
1199
+ _(aten, pinverse) \
1200
+ _(aten, pixel_shuffle) \
1201
+ _(aten, pixel_unshuffle) \
1202
+ _(aten, poisson) \
1203
+ _(aten, poisson_nll_loss) \
1204
+ _(aten, polar) \
1205
+ _(aten, polygamma) \
1206
+ _(aten, polygamma_) \
1207
+ _(aten, positive) \
1208
+ _(aten, pow) \
1209
+ _(aten, pow_) \
1210
+ _(aten, prelu) \
1211
+ _(aten, prod) \
1212
+ _(aten, promote_types) \
1213
+ _(aten, put) \
1214
+ _(aten, put_) \
1215
+ _(aten, q_per_channel_axis) \
1216
+ _(aten, q_per_channel_scales) \
1217
+ _(aten, q_per_channel_zero_points) \
1218
+ _(aten, q_scale) \
1219
+ _(aten, q_zero_point) \
1220
+ _(aten, qr) \
1221
+ _(aten, qscheme) \
1222
+ _(aten, quantile) \
1223
+ _(aten, quantize_per_channel) \
1224
+ _(aten, quantize_per_tensor) \
1225
+ _(aten, quantize_per_tensor_dynamic) \
1226
+ _(aten, quantized_batch_norm) \
1227
+ _(aten, quantized_gru_cell) \
1228
+ _(aten, quantized_lstm_cell) \
1229
+ _(aten, quantized_max_pool1d) \
1230
+ _(aten, quantized_max_pool2d) \
1231
+ _(aten, quantized_max_pool3d) \
1232
+ _(aten, quantized_rnn_relu_cell) \
1233
+ _(aten, quantized_rnn_tanh_cell) \
1234
+ _(aten, rad2deg) \
1235
+ _(aten, rad2deg_) \
1236
+ _(aten, rand) \
1237
+ _(aten, rand_like) \
1238
+ _(aten, randint) \
1239
+ _(aten, randint_like) \
1240
+ _(aten, randn) \
1241
+ _(aten, randn_like) \
1242
+ _(aten, random) \
1243
+ _(aten, random_) \
1244
+ _(aten, randperm) \
1245
+ _(aten, range) \
1246
+ _(aten, ravel) \
1247
+ _(aten, real) \
1248
+ _(aten, reciprocal) \
1249
+ _(aten, reciprocal_) \
1250
+ _(aten, record_stream) \
1251
+ _(aten, refine_names) \
1252
+ _(aten, reflection_pad1d) \
1253
+ _(aten, reflection_pad1d_backward) \
1254
+ _(aten, reflection_pad2d) \
1255
+ _(aten, reflection_pad2d_backward) \
1256
+ _(aten, reflection_pad3d) \
1257
+ _(aten, reflection_pad3d_backward) \
1258
+ _(aten, relu) \
1259
+ _(aten, relu6) \
1260
+ _(aten, relu6_) \
1261
+ _(aten, relu_) \
1262
+ _(aten, remainder) \
1263
+ _(aten, remainder_) \
1264
+ _(aten, rename) \
1265
+ _(aten, rename_) \
1266
+ _(aten, renorm) \
1267
+ _(aten, renorm_) \
1268
+ _(aten, repeat) \
1269
+ _(aten, repeat_interleave) \
1270
+ _(aten, replication_pad1d) \
1271
+ _(aten, replication_pad1d_backward) \
1272
+ _(aten, replication_pad2d) \
1273
+ _(aten, replication_pad2d_backward) \
1274
+ _(aten, replication_pad3d) \
1275
+ _(aten, replication_pad3d_backward) \
1276
+ _(aten, requires_grad) \
1277
+ _(aten, requires_grad_) \
1278
+ _(aten, reshape) \
1279
+ _(aten, reshape_as) \
1280
+ _(aten, resize) \
1281
+ _(aten, resize_) \
1282
+ _(aten, resize_as) \
1283
+ _(aten, resize_as_) \
1284
+ _(aten, resize_as_sparse) \
1285
+ _(aten, resize_as_sparse_) \
1286
+ _(aten, resolve_conj) \
1287
+ _(aten, resolve_neg) \
1288
+ _(aten, result_type) \
1289
+ _(aten, retain_grad) \
1290
+ _(aten, retains_grad) \
1291
+ _(aten, rms_norm) \
1292
+ _(aten, rnn_relu) \
1293
+ _(aten, rnn_relu_cell) \
1294
+ _(aten, rnn_tanh) \
1295
+ _(aten, rnn_tanh_cell) \
1296
+ _(aten, roll) \
1297
+ _(aten, rot90) \
1298
+ _(aten, round) \
1299
+ _(aten, round_) \
1300
+ _(aten, row_indices) \
1301
+ _(aten, row_indices_copy) \
1302
+ _(aten, row_stack) \
1303
+ _(aten, rrelu) \
1304
+ _(aten, rrelu_) \
1305
+ _(aten, rrelu_with_noise) \
1306
+ _(aten, rrelu_with_noise_) \
1307
+ _(aten, rrelu_with_noise_backward) \
1308
+ _(aten, rshift) \
1309
+ _(aten, rsqrt) \
1310
+ _(aten, rsqrt_) \
1311
+ _(aten, rsub) \
1312
+ _(aten, scalar_tensor) \
1313
+ _(aten, scaled_dot_product_attention) \
1314
+ _(aten, scatter) \
1315
+ _(aten, scatter_) \
1316
+ _(aten, scatter_add) \
1317
+ _(aten, scatter_add_) \
1318
+ _(aten, scatter_reduce) \
1319
+ _(aten, scatter_reduce_) \
1320
+ _(aten, searchsorted) \
1321
+ _(aten, segment_reduce) \
1322
+ _(aten, select) \
1323
+ _(aten, select_backward) \
1324
+ _(aten, select_copy) \
1325
+ _(aten, select_scatter) \
1326
+ _(aten, selu) \
1327
+ _(aten, selu_) \
1328
+ _(aten, set) \
1329
+ _(aten, set_) \
1330
+ _(aten, set_data) \
1331
+ _(aten, sgn) \
1332
+ _(aten, sgn_) \
1333
+ _(aten, sigmoid) \
1334
+ _(aten, sigmoid_) \
1335
+ _(aten, sigmoid_backward) \
1336
+ _(aten, sign) \
1337
+ _(aten, sign_) \
1338
+ _(aten, signbit) \
1339
+ _(aten, silu) \
1340
+ _(aten, silu_) \
1341
+ _(aten, silu_backward) \
1342
+ _(aten, sin) \
1343
+ _(aten, sin_) \
1344
+ _(aten, sinc) \
1345
+ _(aten, sinc_) \
1346
+ _(aten, sinh) \
1347
+ _(aten, sinh_) \
1348
+ _(aten, size) \
1349
+ _(aten, slice) \
1350
+ _(aten, slice_backward) \
1351
+ _(aten, slice_copy) \
1352
+ _(aten, slice_inverse) \
1353
+ _(aten, slice_scatter) \
1354
+ _(aten, slogdet) \
1355
+ _(aten, slow_conv3d) \
1356
+ _(aten, slow_conv3d_forward) \
1357
+ _(aten, slow_conv_dilated2d) \
1358
+ _(aten, slow_conv_dilated3d) \
1359
+ _(aten, slow_conv_transpose2d) \
1360
+ _(aten, slow_conv_transpose3d) \
1361
+ _(aten, smm) \
1362
+ _(aten, smooth_l1_loss) \
1363
+ _(aten, smooth_l1_loss_backward) \
1364
+ _(aten, soft_margin_loss) \
1365
+ _(aten, soft_margin_loss_backward) \
1366
+ _(aten, softmax) \
1367
+ _(aten, softplus) \
1368
+ _(aten, softplus_backward) \
1369
+ _(aten, softshrink) \
1370
+ _(aten, softshrink_backward) \
1371
+ _(aten, sort) \
1372
+ _(aten, sparse_bsc_tensor) \
1373
+ _(aten, sparse_bsr_tensor) \
1374
+ _(aten, sparse_compressed_tensor) \
1375
+ _(aten, sparse_coo_tensor) \
1376
+ _(aten, sparse_csc_tensor) \
1377
+ _(aten, sparse_csr_tensor) \
1378
+ _(aten, sparse_dim) \
1379
+ _(aten, sparse_mask) \
1380
+ _(aten, sparse_resize) \
1381
+ _(aten, sparse_resize_) \
1382
+ _(aten, sparse_resize_and_clear) \
1383
+ _(aten, sparse_resize_and_clear_) \
1384
+ _(aten, sparse_sampled_addmm) \
1385
+ _(aten, special_airy_ai) \
1386
+ _(aten, special_bessel_j0) \
1387
+ _(aten, special_bessel_j1) \
1388
+ _(aten, special_bessel_y0) \
1389
+ _(aten, special_bessel_y1) \
1390
+ _(aten, special_chebyshev_polynomial_t) \
1391
+ _(aten, special_chebyshev_polynomial_u) \
1392
+ _(aten, special_chebyshev_polynomial_v) \
1393
+ _(aten, special_chebyshev_polynomial_w) \
1394
+ _(aten, special_digamma) \
1395
+ _(aten, special_entr) \
1396
+ _(aten, special_erf) \
1397
+ _(aten, special_erfc) \
1398
+ _(aten, special_erfcx) \
1399
+ _(aten, special_erfinv) \
1400
+ _(aten, special_exp2) \
1401
+ _(aten, special_expit) \
1402
+ _(aten, special_expm1) \
1403
+ _(aten, special_gammainc) \
1404
+ _(aten, special_gammaincc) \
1405
+ _(aten, special_gammaln) \
1406
+ _(aten, special_hermite_polynomial_h) \
1407
+ _(aten, special_hermite_polynomial_he) \
1408
+ _(aten, special_i0) \
1409
+ _(aten, special_i0e) \
1410
+ _(aten, special_i1) \
1411
+ _(aten, special_i1e) \
1412
+ _(aten, special_laguerre_polynomial_l) \
1413
+ _(aten, special_legendre_polynomial_p) \
1414
+ _(aten, special_log1p) \
1415
+ _(aten, special_log_ndtr) \
1416
+ _(aten, special_log_softmax) \
1417
+ _(aten, special_logit) \
1418
+ _(aten, special_logsumexp) \
1419
+ _(aten, special_modified_bessel_i0) \
1420
+ _(aten, special_modified_bessel_i1) \
1421
+ _(aten, special_modified_bessel_k0) \
1422
+ _(aten, special_modified_bessel_k1) \
1423
+ _(aten, special_multigammaln) \
1424
+ _(aten, special_ndtr) \
1425
+ _(aten, special_ndtri) \
1426
+ _(aten, special_polygamma) \
1427
+ _(aten, special_psi) \
1428
+ _(aten, special_round) \
1429
+ _(aten, special_scaled_modified_bessel_k0) \
1430
+ _(aten, special_scaled_modified_bessel_k1) \
1431
+ _(aten, special_shifted_chebyshev_polynomial_t) \
1432
+ _(aten, special_shifted_chebyshev_polynomial_u) \
1433
+ _(aten, special_shifted_chebyshev_polynomial_v) \
1434
+ _(aten, special_shifted_chebyshev_polynomial_w) \
1435
+ _(aten, special_sinc) \
1436
+ _(aten, special_softmax) \
1437
+ _(aten, special_spherical_bessel_j0) \
1438
+ _(aten, special_xlog1py) \
1439
+ _(aten, special_xlogy) \
1440
+ _(aten, special_zeta) \
1441
+ _(aten, split) \
1442
+ _(aten, split_copy) \
1443
+ _(aten, split_with_sizes) \
1444
+ _(aten, split_with_sizes_copy) \
1445
+ _(aten, sqrt) \
1446
+ _(aten, sqrt_) \
1447
+ _(aten, square) \
1448
+ _(aten, square_) \
1449
+ _(aten, squeeze) \
1450
+ _(aten, squeeze_) \
1451
+ _(aten, squeeze_copy) \
1452
+ _(aten, sspaddmm) \
1453
+ _(aten, stack) \
1454
+ _(aten, std) \
1455
+ _(aten, std_mean) \
1456
+ _(aten, stft) \
1457
+ _(aten, stride) \
1458
+ _(aten, sub) \
1459
+ _(aten, sub_) \
1460
+ _(aten, subtract) \
1461
+ _(aten, subtract_) \
1462
+ _(aten, sum) \
1463
+ _(aten, sum_to_size) \
1464
+ _(aten, svd) \
1465
+ _(aten, swapaxes) \
1466
+ _(aten, swapaxes_) \
1467
+ _(aten, swapdims) \
1468
+ _(aten, swapdims_) \
1469
+ _(aten, sym_constrain_range) \
1470
+ _(aten, sym_constrain_range_for_size) \
1471
+ _(aten, sym_numel) \
1472
+ _(aten, sym_size) \
1473
+ _(aten, sym_storage_offset) \
1474
+ _(aten, sym_stride) \
1475
+ _(aten, t) \
1476
+ _(aten, t_) \
1477
+ _(aten, t_copy) \
1478
+ _(aten, take) \
1479
+ _(aten, take_along_dim) \
1480
+ _(aten, tan) \
1481
+ _(aten, tan_) \
1482
+ _(aten, tanh) \
1483
+ _(aten, tanh_) \
1484
+ _(aten, tanh_backward) \
1485
+ _(aten, tensor_split) \
1486
+ _(aten, tensordot) \
1487
+ _(aten, thnn_conv2d) \
1488
+ _(aten, threshold) \
1489
+ _(aten, threshold_) \
1490
+ _(aten, threshold_backward) \
1491
+ _(aten, tile) \
1492
+ _(aten, to) \
1493
+ _(aten, to_dense) \
1494
+ _(aten, to_dense_backward) \
1495
+ _(aten, to_mkldnn) \
1496
+ _(aten, to_mkldnn_backward) \
1497
+ _(aten, to_padded_tensor) \
1498
+ _(aten, to_sparse) \
1499
+ _(aten, to_sparse_bsc) \
1500
+ _(aten, to_sparse_bsr) \
1501
+ _(aten, to_sparse_csc) \
1502
+ _(aten, to_sparse_csr) \
1503
+ _(aten, topk) \
1504
+ _(aten, trace) \
1505
+ _(aten, trace_backward) \
1506
+ _(aten, transpose) \
1507
+ _(aten, transpose_) \
1508
+ _(aten, transpose_copy) \
1509
+ _(aten, trapezoid) \
1510
+ _(aten, trapz) \
1511
+ _(aten, triangular_solve) \
1512
+ _(aten, tril) \
1513
+ _(aten, tril_) \
1514
+ _(aten, tril_indices) \
1515
+ _(aten, triplet_margin_loss) \
1516
+ _(aten, triu) \
1517
+ _(aten, triu_) \
1518
+ _(aten, triu_indices) \
1519
+ _(aten, true_divide) \
1520
+ _(aten, true_divide_) \
1521
+ _(aten, trunc) \
1522
+ _(aten, trunc_) \
1523
+ _(aten, type_as) \
1524
+ _(aten, unbind) \
1525
+ _(aten, unbind_copy) \
1526
+ _(aten, unflatten) \
1527
+ _(aten, unflatten_dense_tensors) \
1528
+ _(aten, unfold) \
1529
+ _(aten, unfold_backward) \
1530
+ _(aten, unfold_copy) \
1531
+ _(aten, uniform) \
1532
+ _(aten, uniform_) \
1533
+ _(aten, unique_consecutive) \
1534
+ _(aten, unique_dim) \
1535
+ _(aten, unique_dim_consecutive) \
1536
+ _(aten, unsafe_chunk) \
1537
+ _(aten, unsafe_split) \
1538
+ _(aten, unsafe_split_with_sizes) \
1539
+ _(aten, unsqueeze) \
1540
+ _(aten, unsqueeze_) \
1541
+ _(aten, unsqueeze_copy) \
1542
+ _(aten, upsample_bicubic2d) \
1543
+ _(aten, upsample_bicubic2d_backward) \
1544
+ _(aten, upsample_bilinear2d) \
1545
+ _(aten, upsample_bilinear2d_backward) \
1546
+ _(aten, upsample_linear1d) \
1547
+ _(aten, upsample_linear1d_backward) \
1548
+ _(aten, upsample_nearest1d) \
1549
+ _(aten, upsample_nearest1d_backward) \
1550
+ _(aten, upsample_nearest2d) \
1551
+ _(aten, upsample_nearest2d_backward) \
1552
+ _(aten, upsample_nearest3d) \
1553
+ _(aten, upsample_nearest3d_backward) \
1554
+ _(aten, upsample_trilinear3d) \
1555
+ _(aten, upsample_trilinear3d_backward) \
1556
+ _(aten, value_selecting_reduction_backward) \
1557
+ _(aten, values) \
1558
+ _(aten, values_copy) \
1559
+ _(aten, vander) \
1560
+ _(aten, var) \
1561
+ _(aten, var_mean) \
1562
+ _(aten, vdot) \
1563
+ _(aten, view) \
1564
+ _(aten, view_as) \
1565
+ _(aten, view_as_complex) \
1566
+ _(aten, view_as_complex_copy) \
1567
+ _(aten, view_as_real) \
1568
+ _(aten, view_as_real_copy) \
1569
+ _(aten, view_copy) \
1570
+ _(aten, vsplit) \
1571
+ _(aten, vstack) \
1572
+ _(aten, where) \
1573
+ _(aten, xlogy) \
1574
+ _(aten, xlogy_) \
1575
+ _(aten, zero) \
1576
+ _(aten, zero_) \
1577
+ _(aten, zeros) \
1578
+ _(aten, zeros_like)
1579
+
1580
+ #define FORALL_ATTR_BASE_SYMBOLS(_) \
1581
+ _(attr, A) \
1582
+ _(attr, B) \
1583
+ _(attr, C) \
1584
+ _(attr, H) \
1585
+ _(attr, HxW) \
1586
+ _(attr, K) \
1587
+ _(attr, L) \
1588
+ _(attr, LD) \
1589
+ _(attr, LU) \
1590
+ _(attr, LU_data) \
1591
+ _(attr, LU_pivots) \
1592
+ _(attr, M) \
1593
+ _(attr, N) \
1594
+ _(attr, P) \
1595
+ _(attr, Q) \
1596
+ _(attr, R) \
1597
+ _(attr, S) \
1598
+ _(attr, U) \
1599
+ _(attr, UPLO) \
1600
+ _(attr, V) \
1601
+ _(attr, Vh) \
1602
+ _(attr, W) \
1603
+ _(attr, X) \
1604
+ _(attr, a) \
1605
+ _(attr, abs) \
1606
+ _(attr, accumulate) \
1607
+ _(attr, accumulate_matches) \
1608
+ _(attr, activation) \
1609
+ _(attr, addends) \
1610
+ _(attr, adjoint) \
1611
+ _(attr, alg_id) \
1612
+ _(attr, algorithm) \
1613
+ _(attr, alibi_slopes) \
1614
+ _(attr, align_corners) \
1615
+ _(attr, allow_tf32) \
1616
+ _(attr, alpha) \
1617
+ _(attr, amsgrad) \
1618
+ _(attr, anchor) \
1619
+ _(attr, angle) \
1620
+ _(attr, any) \
1621
+ _(attr, api_name) \
1622
+ _(attr, append) \
1623
+ _(attr, approximate) \
1624
+ _(attr, arg1) \
1625
+ _(attr, arg2) \
1626
+ _(attr, arg3) \
1627
+ _(attr, arg_out) \
1628
+ _(attr, assert_msg) \
1629
+ _(attr, assume_unique) \
1630
+ _(attr, atol) \
1631
+ _(attr, attn_bias) \
1632
+ _(attr, attn_mask) \
1633
+ _(attr, average_attn_weights) \
1634
+ _(attr, averaging_const) \
1635
+ _(attr, aweights) \
1636
+ _(attr, axis) \
1637
+ _(attr, axis0) \
1638
+ _(attr, axis1) \
1639
+ _(attr, b) \
1640
+ _(attr, b_hh) \
1641
+ _(attr, b_ih) \
1642
+ _(attr, bag_size) \
1643
+ _(attr, base) \
1644
+ _(attr, batch1) \
1645
+ _(attr, batch2) \
1646
+ _(attr, batch_dim) \
1647
+ _(attr, batch_first) \
1648
+ _(attr, batch_size) \
1649
+ _(attr, batch_sizes) \
1650
+ _(attr, benchmark) \
1651
+ _(attr, beta) \
1652
+ _(attr, beta1) \
1653
+ _(attr, beta2) \
1654
+ _(attr, bias) \
1655
+ _(attr, bias_defined) \
1656
+ _(attr, bias_g) \
1657
+ _(attr, bias_requires_grad) \
1658
+ _(attr, bias_sizes) \
1659
+ _(attr, bidirectional) \
1660
+ _(attr, bin_edges) \
1661
+ _(attr, bins) \
1662
+ _(attr, bit_width) \
1663
+ _(attr, blank) \
1664
+ _(attr, blocksize) \
1665
+ _(attr, boundaries) \
1666
+ _(attr, buffer) \
1667
+ _(attr, ccol_indices) \
1668
+ _(attr, cdim) \
1669
+ _(attr, cdist) \
1670
+ _(attr, ceil_mode) \
1671
+ _(attr, cell_state_fwd) \
1672
+ _(attr, center) \
1673
+ _(attr, ch_axis) \
1674
+ _(attr, check_errors) \
1675
+ _(attr, chunks) \
1676
+ _(attr, coalesced) \
1677
+ _(attr, coefficients) \
1678
+ _(attr, col) \
1679
+ _(attr, col_indices) \
1680
+ _(attr, col_offsets) \
1681
+ _(attr, col_offsets_hh) \
1682
+ _(attr, col_offsets_ih) \
1683
+ _(attr, compressed_A) \
1684
+ _(attr, compressed_idx) \
1685
+ _(attr, compressed_indices) \
1686
+ _(attr, compressed_indices_dtype) \
1687
+ _(attr, compute_log_sumexp) \
1688
+ _(attr, compute_mode) \
1689
+ _(attr, compute_uv) \
1690
+ _(attr, compute_v) \
1691
+ _(attr, condition) \
1692
+ _(attr, copy) \
1693
+ _(attr, correction) \
1694
+ _(attr, count) \
1695
+ _(attr, count_include_pad) \
1696
+ _(attr, counts) \
1697
+ _(attr, cpu_dtype) \
1698
+ _(attr, cpu_enabled) \
1699
+ _(attr, cpu_nested_shape_example) \
1700
+ _(attr, create_graph) \
1701
+ _(attr, crow_indices) \
1702
+ _(attr, cu_seqlens_k) \
1703
+ _(attr, cu_seqlens_q) \
1704
+ _(attr, cuda_dtype) \
1705
+ _(attr, cuda_enabled) \
1706
+ _(attr, cudnn_enable) \
1707
+ _(attr, cudnn_enabled) \
1708
+ _(attr, cum_seq_k) \
1709
+ _(attr, cum_seq_q) \
1710
+ _(attr, custom_mask_type) \
1711
+ _(attr, cx) \
1712
+ _(attr, cx_) \
1713
+ _(attr, cx_tmp) \
1714
+ _(attr, cy) \
1715
+ _(attr, cy_) \
1716
+ _(attr, d) \
1717
+ _(attr, dampening) \
1718
+ _(attr, data) \
1719
+ _(attr, decimals) \
1720
+ _(attr, delta) \
1721
+ _(attr, dense) \
1722
+ _(attr, dense_B) \
1723
+ _(attr, dense_dim) \
1724
+ _(attr, density) \
1725
+ _(attr, dep_token) \
1726
+ _(attr, descending) \
1727
+ _(attr, destination) \
1728
+ _(attr, deterministic) \
1729
+ _(attr, device) \
1730
+ _(attr, device_index) \
1731
+ _(attr, dgrad_glu) \
1732
+ _(attr, diagonal) \
1733
+ _(attr, diagonals) \
1734
+ _(attr, dilation) \
1735
+ _(attr, dim) \
1736
+ _(attr, dim0) \
1737
+ _(attr, dim1) \
1738
+ _(attr, dim2) \
1739
+ _(attr, dimension) \
1740
+ _(attr, dims) \
1741
+ _(attr, dims_other) \
1742
+ _(attr, dims_self) \
1743
+ _(attr, divisor_override) \
1744
+ _(attr, downscale_factor) \
1745
+ _(attr, driver) \
1746
+ _(attr, dropout) \
1747
+ _(attr, dropout_mask) \
1748
+ _(attr, dropout_p) \
1749
+ _(attr, dropout_seed) \
1750
+ _(attr, dropout_state) \
1751
+ _(attr, dst) \
1752
+ _(attr, dtype) \
1753
+ _(attr, dual) \
1754
+ _(attr, dummy) \
1755
+ _(attr, dx) \
1756
+ _(attr, edge_order) \
1757
+ _(attr, eigenvalues) \
1758
+ _(attr, eigenvectors) \
1759
+ _(attr, eigvals) \
1760
+ _(attr, eigvecs) \
1761
+ _(attr, element) \
1762
+ _(attr, elements) \
1763
+ _(attr, ellipsis_idx) \
1764
+ _(attr, embed_dim) \
1765
+ _(attr, enable_gqa) \
1766
+ _(attr, end) \
1767
+ _(attr, end_dim) \
1768
+ _(attr, eps) \
1769
+ _(attr, epsilon) \
1770
+ _(attr, equal_nan) \
1771
+ _(attr, equation) \
1772
+ _(attr, exp_avg_sqs) \
1773
+ _(attr, exp_avgs) \
1774
+ _(attr, expand1) \
1775
+ _(attr, expand2) \
1776
+ _(attr, expand3) \
1777
+ _(attr, exponent) \
1778
+ _(attr, exponential_average_factor) \
1779
+ _(attr, fake_quant_enabled) \
1780
+ _(attr, fake_quant_on) \
1781
+ _(attr, ffn_bias_1) \
1782
+ _(attr, ffn_bias_2) \
1783
+ _(attr, ffn_weight_1) \
1784
+ _(attr, ffn_weight_2) \
1785
+ _(attr, filename) \
1786
+ _(attr, fill) \
1787
+ _(attr, fill_value) \
1788
+ _(attr, flat) \
1789
+ _(attr, forward) \
1790
+ _(attr, found_inf) \
1791
+ _(attr, from) \
1792
+ _(attr, from_) \
1793
+ _(attr, full) \
1794
+ _(attr, full_matrices) \
1795
+ _(attr, fuse_transform_0213) \
1796
+ _(attr, fweights) \
1797
+ _(attr, g) \
1798
+ _(attr, gO) \
1799
+ _(attr, generator) \
1800
+ _(attr, ggI) \
1801
+ _(attr, ggW) \
1802
+ _(attr, ggb) \
1803
+ _(attr, glu) \
1804
+ _(attr, grad) \
1805
+ _(attr, grad_bias) \
1806
+ _(attr, grad_cy) \
1807
+ _(attr, grad_factor) \
1808
+ _(attr, grad_glu) \
1809
+ _(attr, grad_hy) \
1810
+ _(attr, grad_in) \
1811
+ _(attr, grad_input) \
1812
+ _(attr, grad_input_mask) \
1813
+ _(attr, grad_out) \
1814
+ _(attr, grad_out_) \
1815
+ _(attr, grad_output) \
1816
+ _(attr, grad_scale) \
1817
+ _(attr, grad_w) \
1818
+ _(attr, grad_weight) \
1819
+ _(attr, grad_x) \
1820
+ _(attr, grad_y) \
1821
+ _(attr, gradient) \
1822
+ _(attr, grads) \
1823
+ _(attr, grid) \
1824
+ _(attr, group) \
1825
+ _(attr, groups) \
1826
+ _(attr, growth_interval) \
1827
+ _(attr, growth_tracker) \
1828
+ _(attr, half_to_float) \
1829
+ _(attr, has_bias) \
1830
+ _(attr, has_biases) \
1831
+ _(attr, hermitian) \
1832
+ _(attr, hidden_bias) \
1833
+ _(attr, hidden_gates) \
1834
+ _(attr, hidden_size) \
1835
+ _(attr, high) \
1836
+ _(attr, hist) \
1837
+ _(attr, hop_length) \
1838
+ _(attr, hx) \
1839
+ _(attr, hx_) \
1840
+ _(attr, hy_) \
1841
+ _(attr, i1) \
1842
+ _(attr, i2) \
1843
+ _(attr, i3) \
1844
+ _(attr, ignore_index) \
1845
+ _(attr, imag) \
1846
+ _(attr, impl_index) \
1847
+ _(attr, implicit) \
1848
+ _(attr, include_last_offset) \
1849
+ _(attr, include_self) \
1850
+ _(attr, increasing) \
1851
+ _(attr, ind) \
1852
+ _(attr, index) \
1853
+ _(attr, index_dtype) \
1854
+ _(attr, indexing) \
1855
+ _(attr, indices) \
1856
+ _(attr, info) \
1857
+ _(attr, initial) \
1858
+ _(attr, innerKTiles) \
1859
+ _(attr, input) \
1860
+ _(attr, input1) \
1861
+ _(attr, input2) \
1862
+ _(attr, input3) \
1863
+ _(attr, input_bias) \
1864
+ _(attr, input_dtype) \
1865
+ _(attr, input_g) \
1866
+ _(attr, input_gates) \
1867
+ _(attr, input_lengths) \
1868
+ _(attr, input_scale) \
1869
+ _(attr, input_size) \
1870
+ _(attr, input_sizes) \
1871
+ _(attr, input_zero_point) \
1872
+ _(attr, inputs) \
1873
+ _(attr, interpolation) \
1874
+ _(attr, interpolation_mode) \
1875
+ _(attr, inv_scale) \
1876
+ _(attr, inverse) \
1877
+ _(attr, invert) \
1878
+ _(attr, invstd) \
1879
+ _(attr, is_causal) \
1880
+ _(attr, is_coalesced) \
1881
+ _(attr, is_crow) \
1882
+ _(attr, is_first_step) \
1883
+ _(attr, is_matrix) \
1884
+ _(attr, is_result) \
1885
+ _(attr, is_target) \
1886
+ _(attr, k) \
1887
+ _(attr, keepdim) \
1888
+ _(attr, kernel_size) \
1889
+ _(attr, key) \
1890
+ _(attr, label_smoothing) \
1891
+ _(attr, lambd) \
1892
+ _(attr, largest) \
1893
+ _(attr, last_dim_size) \
1894
+ _(attr, layersOutputs) \
1895
+ _(attr, layout) \
1896
+ _(attr, left) \
1897
+ _(attr, length) \
1898
+ _(attr, lengths) \
1899
+ _(attr, level) \
1900
+ _(attr, like) \
1901
+ _(attr, list) \
1902
+ _(attr, log_alpha) \
1903
+ _(attr, log_input) \
1904
+ _(attr, log_probs) \
1905
+ _(attr, log_target) \
1906
+ _(attr, logabsdet) \
1907
+ _(attr, logsumexp) \
1908
+ _(attr, low) \
1909
+ _(attr, lower) \
1910
+ _(attr, lr) \
1911
+ _(attr, lr_decay) \
1912
+ _(attr, ltm) \
1913
+ _(attr, m) \
1914
+ _(attr, mantissa) \
1915
+ _(attr, margin) \
1916
+ _(attr, mask) \
1917
+ _(attr, mask_check) \
1918
+ _(attr, mask_type) \
1919
+ _(attr, masked_grad) \
1920
+ _(attr, mat) \
1921
+ _(attr, mat1) \
1922
+ _(attr, mat1_meta) \
1923
+ _(attr, mat2) \
1924
+ _(attr, matrices) \
1925
+ _(attr, max) \
1926
+ _(attr, max_exp_avg_sqs) \
1927
+ _(attr, max_k) \
1928
+ _(attr, max_lengths) \
1929
+ _(attr, max_norm) \
1930
+ _(attr, max_q) \
1931
+ _(attr, max_seqlen) \
1932
+ _(attr, max_seqlen_k) \
1933
+ _(attr, max_seqlen_q) \
1934
+ _(attr, max_size) \
1935
+ _(attr, max_val) \
1936
+ _(attr, max_values) \
1937
+ _(attr, maximize) \
1938
+ _(attr, maximum_indices) \
1939
+ _(attr, maxnorm) \
1940
+ _(attr, mean) \
1941
+ _(attr, median) \
1942
+ _(attr, memory_format) \
1943
+ _(attr, meta) \
1944
+ _(attr, min) \
1945
+ _(attr, min_indices) \
1946
+ _(attr, min_seqlen) \
1947
+ _(attr, min_val) \
1948
+ _(attr, minlength) \
1949
+ _(attr, mode) \
1950
+ _(attr, momentum) \
1951
+ _(attr, momentum_buffer_list) \
1952
+ _(attr, n) \
1953
+ _(attr, n_bins) \
1954
+ _(attr, n_fft) \
1955
+ _(attr, names) \
1956
+ _(attr, nan) \
1957
+ _(attr, need_weights) \
1958
+ _(attr, neg_log_likelihood) \
1959
+ _(attr, negative) \
1960
+ _(attr, negative_slope) \
1961
+ _(attr, neginf) \
1962
+ _(attr, nested_size) \
1963
+ _(attr, nested_strides) \
1964
+ _(attr, nesterov) \
1965
+ _(attr, new_data) \
1966
+ _(attr, nnz) \
1967
+ _(attr, noise) \
1968
+ _(attr, non_blocking) \
1969
+ _(attr, norm) \
1970
+ _(attr, norm_bias_1) \
1971
+ _(attr, norm_bias_2) \
1972
+ _(attr, norm_first) \
1973
+ _(attr, norm_type) \
1974
+ _(attr, norm_weight_1) \
1975
+ _(attr, norm_weight_2) \
1976
+ _(attr, normalization) \
1977
+ _(attr, normalized) \
1978
+ _(attr, normalized_shape) \
1979
+ _(attr, nt_example) \
1980
+ _(attr, num_chunks) \
1981
+ _(attr, num_classes) \
1982
+ _(attr, num_generated) \
1983
+ _(attr, num_groups) \
1984
+ _(attr, num_head) \
1985
+ _(attr, num_heads) \
1986
+ _(attr, num_layers) \
1987
+ _(attr, num_parallel) \
1988
+ _(attr, num_samples) \
1989
+ _(attr, num_splits_key) \
1990
+ _(attr, num_weights) \
1991
+ _(attr, numel) \
1992
+ _(attr, observer_on) \
1993
+ _(attr, offset) \
1994
+ _(attr, offset2bag) \
1995
+ _(attr, offsets) \
1996
+ _(attr, onesided) \
1997
+ _(attr, ord) \
1998
+ _(attr, order) \
1999
+ _(attr, other) \
2000
+ _(attr, out) \
2001
+ _(attr, out0) \
2002
+ _(attr, out1) \
2003
+ _(attr, out2) \
2004
+ _(attr, out3) \
2005
+ _(attr, out4) \
2006
+ _(attr, out5) \
2007
+ _(attr, out6) \
2008
+ _(attr, out_channel) \
2009
+ _(attr, out_dim) \
2010
+ _(attr, out_dtype) \
2011
+ _(attr, out_int32) \
2012
+ _(attr, outdim) \
2013
+ _(attr, output) \
2014
+ _(attr, output_mask) \
2015
+ _(attr, output_padding) \
2016
+ _(attr, output_scale) \
2017
+ _(attr, output_size) \
2018
+ _(attr, output_zero_point) \
2019
+ _(attr, p) \
2020
+ _(attr, packed) \
2021
+ _(attr, packed_hh) \
2022
+ _(attr, packed_ih) \
2023
+ _(attr, packed_weight) \
2024
+ _(attr, pad) \
2025
+ _(attr, pad_mode) \
2026
+ _(attr, padded) \
2027
+ _(attr, padding) \
2028
+ _(attr, padding_idx) \
2029
+ _(attr, padding_mode) \
2030
+ _(attr, padding_side) \
2031
+ _(attr, padding_value) \
2032
+ _(attr, params) \
2033
+ _(attr, path) \
2034
+ _(attr, pdist) \
2035
+ _(attr, per_row_fake_quant) \
2036
+ _(attr, per_sample_weights) \
2037
+ _(attr, periodic) \
2038
+ _(attr, philox_offset) \
2039
+ _(attr, philox_seed) \
2040
+ _(attr, physical_layout) \
2041
+ _(attr, pin_memory) \
2042
+ _(attr, pivot) \
2043
+ _(attr, pivots) \
2044
+ _(attr, plain_idx) \
2045
+ _(attr, plain_indices) \
2046
+ _(attr, pos_weight) \
2047
+ _(attr, posinf) \
2048
+ _(attr, positive) \
2049
+ _(attr, pow) \
2050
+ _(attr, prepend) \
2051
+ _(attr, primal) \
2052
+ _(attr, prob) \
2053
+ _(attr, proj_bias) \
2054
+ _(attr, proj_size) \
2055
+ _(attr, proj_weight) \
2056
+ _(attr, q) \
2057
+ _(attr, qGroupSize) \
2058
+ _(attr, qScaleAndZeros) \
2059
+ _(attr, qkv) \
2060
+ _(attr, qkv_bias) \
2061
+ _(attr, qkv_weight) \
2062
+ _(attr, qtensor) \
2063
+ _(attr, quant_max) \
2064
+ _(attr, quant_min) \
2065
+ _(attr, quasi) \
2066
+ _(attr, query) \
2067
+ _(attr, r) \
2068
+ _(attr, ragged_idx) \
2069
+ _(attr, random_samples) \
2070
+ _(attr, range) \
2071
+ _(attr, rank) \
2072
+ _(attr, ratio) \
2073
+ _(attr, rcond) \
2074
+ _(attr, real) \
2075
+ _(attr, reduce) \
2076
+ _(attr, reduce_range) \
2077
+ _(attr, reduction) \
2078
+ _(attr, repeats) \
2079
+ _(attr, replacement) \
2080
+ _(attr, requires_grad) \
2081
+ _(attr, reserve) \
2082
+ _(attr, reserveSpace) \
2083
+ _(attr, reservedSpace) \
2084
+ _(attr, residuals) \
2085
+ _(attr, result) \
2086
+ _(attr, retain_graph) \
2087
+ _(attr, return_complex) \
2088
+ _(attr, return_counts) \
2089
+ _(attr, return_debug_mask) \
2090
+ _(attr, return_inverse) \
2091
+ _(attr, reverse) \
2092
+ _(attr, right) \
2093
+ _(attr, rounding_mode) \
2094
+ _(attr, row) \
2095
+ _(attr, row_indices) \
2096
+ _(attr, rstd) \
2097
+ _(attr, rtol) \
2098
+ _(attr, running_max) \
2099
+ _(attr, running_mean) \
2100
+ _(attr, running_min) \
2101
+ _(attr, running_var) \
2102
+ _(attr, s) \
2103
+ _(attr, save_invstd) \
2104
+ _(attr, save_mean) \
2105
+ _(attr, save_var) \
2106
+ _(attr, save_var_transform) \
2107
+ _(attr, saved_g) \
2108
+ _(attr, saved_norms) \
2109
+ _(attr, saved_v) \
2110
+ _(attr, scalar) \
2111
+ _(attr, scalar1) \
2112
+ _(attr, scalar2) \
2113
+ _(attr, scalars) \
2114
+ _(attr, scale) \
2115
+ _(attr, scale_a) \
2116
+ _(attr, scale_b) \
2117
+ _(attr, scale_backoff_factor) \
2118
+ _(attr, scale_factors) \
2119
+ _(attr, scale_grad_by_freq) \
2120
+ _(attr, scale_growth_factor) \
2121
+ _(attr, scale_hh) \
2122
+ _(attr, scale_ih) \
2123
+ _(attr, scale_result) \
2124
+ _(attr, scales) \
2125
+ _(attr, scales_d) \
2126
+ _(attr, scales_h) \
2127
+ _(attr, scales_w) \
2128
+ _(attr, sections) \
2129
+ _(attr, seed) \
2130
+ _(attr, self) \
2131
+ _(attr, self_is_result) \
2132
+ _(attr, self_num_batch_dims) \
2133
+ _(attr, self_or_result) \
2134
+ _(attr, self_sizes) \
2135
+ _(attr, seqlen_k) \
2136
+ _(attr, sequences) \
2137
+ _(attr, seqused_k) \
2138
+ _(attr, shape) \
2139
+ _(attr, shared) \
2140
+ _(attr, shared_storage_dqdkdv) \
2141
+ _(attr, shifts) \
2142
+ _(attr, side) \
2143
+ _(attr, sigma) \
2144
+ _(attr, sign) \
2145
+ _(attr, singular_values) \
2146
+ _(attr, size) \
2147
+ _(attr, sizes) \
2148
+ _(attr, skip_first) \
2149
+ _(attr, sobolstate) \
2150
+ _(attr, solution) \
2151
+ _(attr, some) \
2152
+ _(attr, sorted) \
2153
+ _(attr, sorted_sequence) \
2154
+ _(attr, sorter) \
2155
+ _(attr, source) \
2156
+ _(attr, spacing) \
2157
+ _(attr, sparse) \
2158
+ _(attr, sparse_dim) \
2159
+ _(attr, sparse_grad) \
2160
+ _(attr, split_size) \
2161
+ _(attr, split_sizes) \
2162
+ _(attr, src) \
2163
+ _(attr, stable) \
2164
+ _(attr, start) \
2165
+ _(attr, start_dim) \
2166
+ _(attr, state_steps) \
2167
+ _(attr, state_sums) \
2168
+ _(attr, std) \
2169
+ _(attr, step) \
2170
+ _(attr, steps) \
2171
+ _(attr, storage_offset) \
2172
+ _(attr, stride) \
2173
+ _(attr, sum_dy) \
2174
+ _(attr, sum_dy_xmu) \
2175
+ _(attr, sumdim) \
2176
+ _(attr, swap) \
2177
+ _(attr, symmetric_quant) \
2178
+ _(attr, t) \
2179
+ _(attr, tangent) \
2180
+ _(attr, target) \
2181
+ _(attr, target_lengths) \
2182
+ _(attr, targets) \
2183
+ _(attr, tau) \
2184
+ _(attr, tensor) \
2185
+ _(attr, tensor1) \
2186
+ _(attr, tensor2) \
2187
+ _(attr, tensor_indices_or_sections) \
2188
+ _(attr, tensors) \
2189
+ _(attr, tensors1) \
2190
+ _(attr, test_element) \
2191
+ _(attr, test_elements) \
2192
+ _(attr, the_template) \
2193
+ _(attr, theta) \
2194
+ _(attr, thread_masks) \
2195
+ _(attr, threshold) \
2196
+ _(attr, to) \
2197
+ _(attr, tol) \
2198
+ _(attr, total) \
2199
+ _(attr, total_L) \
2200
+ _(attr, total_length) \
2201
+ _(attr, total_weight) \
2202
+ _(attr, train) \
2203
+ _(attr, training) \
2204
+ _(attr, transpose) \
2205
+ _(attr, transpose_result) \
2206
+ _(attr, transposed) \
2207
+ _(attr, type1) \
2208
+ _(attr, type2) \
2209
+ _(attr, unbiased) \
2210
+ _(attr, unitriangular) \
2211
+ _(attr, unpack_data) \
2212
+ _(attr, unpack_pivots) \
2213
+ _(attr, unroll_dim) \
2214
+ _(attr, unsafe) \
2215
+ _(attr, update) \
2216
+ _(attr, upper) \
2217
+ _(attr, upscale_factor) \
2218
+ _(attr, use_cutlass) \
2219
+ _(attr, use_fast_accum) \
2220
+ _(attr, use_gelu) \
2221
+ _(attr, use_input_stats) \
2222
+ _(attr, v) \
2223
+ _(attr, value) \
2224
+ _(attr, values) \
2225
+ _(attr, var) \
2226
+ _(attr, vec) \
2227
+ _(attr, vec1) \
2228
+ _(attr, vec2) \
2229
+ _(attr, w_hh) \
2230
+ _(attr, w_ih) \
2231
+ _(attr, weight) \
2232
+ _(attr, weight0) \
2233
+ _(attr, weight1) \
2234
+ _(attr, weight2) \
2235
+ _(attr, weight3) \
2236
+ _(attr, weight4) \
2237
+ _(attr, weight_arr) \
2238
+ _(attr, weight_buf) \
2239
+ _(attr, weight_decay) \
2240
+ _(attr, weight_g) \
2241
+ _(attr, weight_scale) \
2242
+ _(attr, weight_stride0) \
2243
+ _(attr, weight_zero_point) \
2244
+ _(attr, weights) \
2245
+ _(attr, win_length) \
2246
+ _(attr, window) \
2247
+ _(attr, window_length) \
2248
+ _(attr, window_size) \
2249
+ _(attr, window_size_left) \
2250
+ _(attr, window_size_right) \
2251
+ _(attr, with_replacement) \
2252
+ _(attr, workspace) \
2253
+ _(attr, wrap) \
2254
+ _(attr, x) \
2255
+ _(attr, x1) \
2256
+ _(attr, x2) \
2257
+ _(attr, y) \
2258
+ _(attr, z) \
2259
+ _(attr, z_state) \
2260
+ _(attr, zero_infinity) \
2261
+ _(attr, zero_point) \
2262
+ _(attr, zero_point_hh) \
2263
+ _(attr, zero_point_ih) \
2264
+ _(attr, zero_points)
.venv/lib/python3.11/site-packages/torch/include/ATen/core/blob.h ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <type_traits>
4
+
5
+ #include <c10/util/intrusive_ptr.h>
6
+ #include <c10/util/typeid.h>
7
+ #include <c10/macros/Macros.h>
8
+
9
+ namespace caffe2 {
10
+
11
+ class Tensor;
12
+
13
+ /**
14
+ * @brief Blob is a general container that hosts a typed pointer.
15
+ *
16
+ * A Blob hosts a pointer as well as its type, and takes charge of deleting it
17
+ * properly when the blob is deallocated or re-allocated with a new type. A blob
18
+ * could contain anything, although the most common case is to contain a Tensor.
19
+ */
20
+ class TORCH_API Blob final : public c10::intrusive_ptr_target {
21
+ public:
22
+ /**
23
+ * Initializes an empty Blob.
24
+ */
25
+ Blob() noexcept : meta_() {}
26
+ ~Blob() override {
27
+ Reset();
28
+ }
29
+
30
+ Blob(Blob&& other) noexcept : Blob() {
31
+ swap(other);
32
+ }
33
+
34
+ Blob& operator=(Blob&& other) noexcept {
35
+ Blob(std::move(other)).swap(*this);
36
+ return *this;
37
+ }
38
+
39
+ /**
40
+ * Checks if the content stored in the blob is of type T.
41
+ */
42
+ template <class T>
43
+ bool IsType() const noexcept {
44
+ return meta_.Match<T>();
45
+ }
46
+
47
+ /**
48
+ * Returns the meta info of the blob.
49
+ */
50
+ const TypeMeta meta() const noexcept {
51
+ return meta_;
52
+ }
53
+
54
+ /**
55
+ * Returns a printable typename of the blob.
56
+ */
57
+ c10::string_view TypeName() const noexcept {
58
+ return meta_.name();
59
+ }
60
+
61
+ /**
62
+ * @brief Gets the const reference of the stored object. The code checks if
63
+ * the stored object is of the desired type.
64
+ */
65
+ // TODO(jerryzh): add a Get(c10::DeviceType) function?
66
+ template <class T>
67
+ const T& Get() const {
68
+ TORCH_INTERNAL_ASSERT(
69
+ IsType<T>(),
70
+ "wrong type for the Blob instance. Blob contains ",
71
+ meta_.name(),
72
+ " while caller expects ",
73
+ TypeMeta::TypeName<T>());
74
+ // TODO: after we add Get<Tensor>(c10::DeviceType)
75
+ // and changed all the callsites, we can add
76
+ // a static assert here to enforce T != Tensor
77
+ return *static_cast<const T*>(pointer_);
78
+ }
79
+
80
+ const void* GetRaw() const noexcept {
81
+ return pointer_;
82
+ }
83
+ void* GetRaw() noexcept {
84
+ return pointer_;
85
+ }
86
+
87
+ /**
88
+ * @brief Gets a mutable pointer to the stored object.
89
+ *
90
+ * If the current object is not of the right type, a new object is created
91
+ * and the old object is freed. Note that type T should have a default
92
+ * constructor. Otherwise, create the object yourself first, and use
93
+ * Reset().
94
+ */
95
+ template <class T>
96
+ T* GetMutable() {
97
+ static_assert(
98
+ std::is_default_constructible<T>::value,
99
+ "GetMutable can't be called with non-default-constructible types. "
100
+ "Try using specialized methods");
101
+ if (IsType<T>()) {
102
+ return static_cast<T*>(pointer_);
103
+ } else {
104
+ // TODO Re-enable logging
105
+ // VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<T>();
106
+ return Reset<T>(new T());
107
+ }
108
+ }
109
+
110
+ template <class T>
111
+ T* GetMutableOrNull() {
112
+ if (IsType<T>()) {
113
+ return static_cast<T*>(pointer_);
114
+ } else {
115
+ return nullptr;
116
+ }
117
+ }
118
+
119
+ /**
120
+ * Sets the underlying object to the allocated one. The Blob then takes over
121
+ * the ownership of the passed in pointer. If there is already an object in
122
+ * the Blob, the old object is freed.
123
+ *
124
+ * This is used when the underlying class T does not have a default ctor, or
125
+ * complex initializations needs to be done outside the blob.
126
+ */
127
+ template <class T>
128
+ T* Reset(T* allocated) {
129
+ free_();
130
+ meta_ = TypeMeta::Make<T>();
131
+ pointer_ = static_cast<void*>(allocated);
132
+ has_ownership_ = true;
133
+ return allocated;
134
+ }
135
+
136
+ /**
137
+ * Sets the underlying object to the allocated one, but does not take over
138
+ * the ownership of the passed in pointer. If there is already an object in
139
+ * the Blob, the old object is freed.
140
+ *
141
+ * Unlike Reset, this does not take over the ownership of the pointer and the
142
+ * caller is responsible for making sure that the lifetime of the allocated
143
+ * blob outlasts the lifetime of any access to this blob, until another Reset
144
+ * call is made or the blob is destructed.
145
+ */
146
+ template <class T>
147
+ std::remove_const_t<T>* ShareExternal(
148
+ std::remove_const_t<T>* allocated) {
149
+ return static_cast<T*>(ShareExternal(
150
+ static_cast<void*>(allocated),
151
+ TypeMeta::Make<std::remove_const_t<T>>()));
152
+ }
153
+
154
+ void* ShareExternal(void* allocated, const TypeMeta meta) {
155
+ free_();
156
+ meta_ = meta;
157
+ pointer_ = allocated;
158
+ has_ownership_ = false;
159
+ return allocated;
160
+ }
161
+
162
+ /**
163
+ * Resets the Blob to an empty one.
164
+ */
165
+ void Reset() {
166
+ free_();
167
+ pointer_ = nullptr;
168
+ meta_ = TypeMeta();
169
+ has_ownership_ = false;
170
+ }
171
+
172
+ /**
173
+ * @brief Swaps the underlying storage of two blobs.
174
+ */
175
+ void swap(Blob& rhs) noexcept {
176
+ using std::swap;
177
+ swap(meta_, rhs.meta_);
178
+ swap(pointer_, rhs.pointer_);
179
+ swap(has_ownership_, rhs.has_ownership_);
180
+ }
181
+
182
+ private:
183
+ void free_() {
184
+ if (has_ownership_ && pointer_ != nullptr) {
185
+ (*meta_.deleteFn())(pointer_);
186
+ }
187
+ }
188
+
189
+ TypeMeta meta_;
190
+ void* pointer_{nullptr};
191
+ bool has_ownership_{false};
192
+
193
+ C10_DISABLE_COPY_AND_ASSIGN(Blob);
194
+ };
195
+
196
+ inline void swap(Blob& lhs, Blob& rhs) noexcept {
197
+ lhs.swap(rhs);
198
+ }
199
+
200
+ inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
201
+ return out << "Blob[" << v.TypeName() << "]";
202
+ }
203
+
204
+ } // namespace caffe2
.venv/lib/python3.11/site-packages/torch/include/ATen/core/function.h ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/function_schema.h>
4
+ #include <ATen/core/ivalue.h>
5
+ #include <ATen/core/qualified_name.h>
6
+ #include <c10/util/Exception.h>
7
+ #include <c10/util/FunctionRef.h>
8
+
9
+ namespace c10 {
10
+ struct FunctionSchema;
11
+ };
12
+
13
+ namespace at {
14
+ TORCH_API void launch(std::function<void()> func);
15
+ }
16
+
17
+ namespace torch::jit {
18
+
19
+ struct Graph;
20
+ struct Code;
21
+
22
+ namespace mobile {
23
+ struct Code;
24
+ }
25
+
26
+ using Stack = std::vector<at::IValue>;
27
+ using Kwargs = std::unordered_map<std::string, at::IValue>;
28
+ struct RecursiveMethodCallError : public std::exception {};
29
+ using TaskLauncher = std::function<void(std::function<void()>)>;
30
+
31
+ TORCH_API void preoptimizeGraph(
32
+ std::shared_ptr<Graph>& graph,
33
+ bool disable_autocast = false);
34
+
35
+ // A Function is a pure Graph with no implicit `self` object bound.
36
+ // It contains schema information and the executor that manages the
37
+ // execution of the function. Method is a wrapper around an
38
+ // underlying Function that also provides a `self` object.
39
+ struct TORCH_API Function {
40
+ Function() = default;
41
+ Function(const Function&) = default;
42
+ Function& operator=(const Function&) = default;
43
+ Function(Function&&) noexcept = default;
44
+ Function& operator=(Function&&) noexcept = default;
45
+ virtual c10::string_view doc_string() const {
46
+ static constexpr c10::string_view no_doc_string = "";
47
+ return no_doc_string;
48
+ }
49
+
50
+ virtual bool isGraphFunction() const {
51
+ return false;
52
+ }
53
+
54
+ virtual void run(Stack& stack) = 0;
55
+
56
+ virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(
57
+ Stack& /*stack*/,
58
+ // NOLINTNEXTLINE(performance-unnecessary-value-param)
59
+ C10_UNUSED TaskLauncher taskLauncher = at::launch) {
60
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
61
+ return {};
62
+ }
63
+
64
+ at::IValue operator()(Stack stack, const Kwargs& kwargs = Kwargs()) {
65
+ getSchema().checkAndNormalizeInputs(stack, kwargs);
66
+ run(stack);
67
+ return stack.front();
68
+ }
69
+
70
+ virtual const c10::QualifiedName& qualname() const = 0;
71
+
72
+ const std::string& name() const {
73
+ return qualname().name();
74
+ }
75
+
76
+ // if this isn't yet defined, run its method_creator function
77
+ virtual void ensure_defined() = 0;
78
+
79
+ virtual const c10::FunctionSchema& getSchema() const = 0;
80
+
81
+ virtual size_t num_inputs() const = 0;
82
+
83
+ virtual Function& setSchema(c10::FunctionSchema schema) = 0;
84
+
85
+ // call() defines how different interpreter implementations interacts with
86
+ // Function objects. Basically interpreters need to provide a callback to
87
+ // communicate to Functions what to do if provided a Code object.
88
+ // Alternatively we could design the signature to return an optional Code
89
+ // object, but that requires special handling the null case in interpreter
90
+ // and the fallback behavior is not well defined by interpreter but rather
91
+ // Function themselves, so a callback approach is more reasonable than
92
+ // returning values.
93
+ // If call() returns true, then callback completes successfully, otherwise
94
+ // call() returns false.
95
+
96
+ // Overload for server interpreter, a bailout size is needed for graph
97
+ // executor.
98
+ virtual bool call(
99
+ Stack&,
100
+ std::optional<size_t>,
101
+ c10::function_ref<void(const Code&)>) {
102
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
103
+ return false;
104
+ }
105
+
106
+ // Overload for mobile interpreter.
107
+ virtual bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) {
108
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
109
+ return false;
110
+ }
111
+
112
+ virtual ~Function() = default;
113
+ };
114
+ } // namespace torch::jit
.venv/lib/python3.11/site-packages/torch/include/ATen/core/functional.h ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <vector>
4
+ #include <c10/util/ArrayRef.h>
5
+
6
+ namespace c10 {
7
+
8
+ // The passed in function must take T by value (T), or by
9
+ // const reference (const T&); taking T by non-const reference
10
+ // will result in an error like:
11
+ //
12
+ // error: no type named 'type' in 'class std::invoke_result<foobar::__lambda, T>'
13
+ //
14
+ // No explicit template parameters are required.
15
+
16
+ // Overload for explicit function and ArrayRef
17
+ template<class F, class T>
18
+ inline auto fmap(const T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> {
19
+ std::vector<decltype(fn(*inputs.begin()))> r;
20
+ r.reserve(inputs.size());
21
+ for(const auto & input : inputs)
22
+ r.push_back(fn(input));
23
+ return r;
24
+ }
25
+
26
+ // C++ forbids taking an address of a constructor, so here's a workaround...
27
+ // Overload for constructor (R) application
28
+ template<typename R, typename T>
29
+ inline std::vector<R> fmap(const T& inputs) {
30
+ std::vector<R> r;
31
+ r.reserve(inputs.size());
32
+ for(auto & input : inputs)
33
+ r.push_back(R(input));
34
+ return r;
35
+ }
36
+
37
+ template<typename F, typename T>
38
+ inline std::vector<T> filter(at::ArrayRef<T> inputs, const F& fn) {
39
+ std::vector<T> r;
40
+ r.reserve(inputs.size());
41
+ for(auto & input : inputs) {
42
+ if (fn(input)) {
43
+ r.push_back(input);
44
+ }
45
+ }
46
+ return r;
47
+ }
48
+
49
+ template<typename F, typename T>
50
+ inline std::vector<T> filter(const std::vector<T>& inputs, const F& fn) {
51
+ return filter<F, T>(static_cast<at::ArrayRef<T>>(inputs), fn);
52
+ }
53
+
54
+ } // namespace c10
.venv/lib/python3.11/site-packages/torch/include/ATen/core/ivalue_inl.h ADDED
@@ -0,0 +1,2539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <condition_variable>
4
+ #include <memory>
5
+ #include <optional>
6
+ #include <type_traits>
7
+ #include <utility>
8
+
9
+ #include <ATen/core/Dict.h>
10
+ #include <ATen/core/List.h>
11
+ #include <ATen/core/IListRef.h>
12
+ #include <ATen/core/functional.h>
13
+ #include <ATen/core/jit_type.h>
14
+ #include <ATen/core/qualified_name.h>
15
+ #include <ATen/core/rref_interface.h>
16
+ #include <ATen/core/symbol.h>
17
+ #include <c10/core/DeviceGuard.h>
18
+ #include <c10/core/Event.h>
19
+ #include <c10/core/Scalar.h>
20
+ #include <c10/core/Stream.h>
21
+ #include <c10/core/StreamGuard.h>
22
+ #include <c10/core/TensorImpl.h>
23
+ #include <c10/core/UndefinedTensorImpl.h>
24
+ #include <c10/core/impl/DeviceGuardImplInterface.h>
25
+ #include <c10/util/FunctionRef.h>
26
+ #include <c10/util/Logging.h>
27
+ #include <c10/util/hash.h>
28
+ #include <c10/util/intrusive_ptr.h>
29
+ #include <c10/util/irange.h>
30
+
31
+ namespace torch {
32
+ namespace jit {
33
+ struct Function;
34
+ struct CompilationUnit;
35
+ } // namespace jit
36
+ TORCH_API bool isCustomClass(const c10::IValue& v);
37
+ } // namespace torch
38
+ namespace c10 {
39
+ struct IValue;
40
+ struct ClassType;
41
+ struct TupleType;
42
+ struct EnumType;
43
+ struct InferredType;
44
+
45
+ // For custom class __init__ registration, we need to pass in a function
46
+ // that looks like this: [](IValue x, args...)
47
+
48
+ // However, make_boxed_from_unboxed_functor.h automatically sets the input types
49
+ // of the function by introspecting the types of the functor (which is IValue in
50
+ // this case). However, we need the type it binds to be Foo.
51
+
52
+ // Instead, we pass in a lambda [](ivalue_holder<CurClass> x, args...) from
53
+ // which getTypePtr can recover the original class pointer.
54
+
55
+ template <typename TaggedCapsuleType>
56
+ struct tagged_capsule {
57
+ IValue ivalue;
58
+ };
59
+
60
+ template <class T, class NullType>
61
+ c10::intrusive_ptr<T, NullType> IValue::moveToIntrusivePtr() {
62
+ auto t = c10::intrusive_ptr<T, NullType>::reclaim(
63
+ payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
64
+ ? NullType::singleton()
65
+ : static_cast<T*>(payload.u.as_intrusive_ptr));
66
+ clearToNone();
67
+ return t;
68
+ }
69
+ template <typename T, class NullType>
70
+ c10::intrusive_ptr<T, NullType> IValue::toIntrusivePtr() const {
71
+ if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
72
+ return c10::intrusive_ptr<T, NullType>();
73
+ }
74
+ c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
75
+ return c10::intrusive_ptr<T, NullType>::reclaim(
76
+ static_cast<T*>(payload.u.as_intrusive_ptr));
77
+ }
78
+
79
+ template <class T, class U>
80
+ intrusive_ptr<T> static_intrusive_pointer_cast(intrusive_ptr<U> r) {
81
+ return intrusive_ptr<T>::reclaim(static_cast<T*>(r.release()));
82
+ }
83
+
84
+ template <class T, class U>
85
+ intrusive_ptr<T> dynamic_intrusive_pointer_cast(intrusive_ptr<U> r) {
86
+ return intrusive_ptr<T>::reclaim(dynamic_cast<T*>(r.release()));
87
+ }
88
+
89
+ inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() && {
90
+ AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
91
+ return moveToIntrusivePtr<ivalue::Future>();
92
+ }
93
+ inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() const& {
94
+ AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
95
+ return toIntrusivePtr<ivalue::Future>();
96
+ }
97
+ inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() && {
98
+ AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
99
+ return moveToIntrusivePtr<ivalue::Await>();
100
+ }
101
+ inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() const& {
102
+ AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
103
+ return toIntrusivePtr<ivalue::Await>();
104
+ }
105
+ inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() && {
106
+ AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
107
+ return moveToIntrusivePtr<c10::RRefInterface>();
108
+ }
109
+ inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() const& {
110
+ AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
111
+ return toIntrusivePtr<c10::RRefInterface>();
112
+ }
113
+ inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() && {
114
+ AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
115
+ return moveToIntrusivePtr<at::Quantizer>();
116
+ }
117
+ inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() const& {
118
+ AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
119
+ return toIntrusivePtr<at::Quantizer>();
120
+ }
121
+ inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() && {
122
+ AT_ASSERT(isString(), "Expected String but got ", tagKind());
123
+ return moveToIntrusivePtr<ivalue::ConstantString>();
124
+ }
125
+ inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() const& {
126
+ AT_ASSERT(isString(), "Expected String but got ", tagKind());
127
+ return toIntrusivePtr<ivalue::ConstantString>();
128
+ }
129
+ inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() && {
130
+ AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
131
+ return moveToIntrusivePtr<ivalue::Object>();
132
+ }
133
+ inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() const& {
134
+ AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
135
+ return toIntrusivePtr<ivalue::Object>();
136
+ }
137
+ inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::
138
+ toPyObjectHolder() && {
139
+ TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
140
+ return moveToIntrusivePtr<ivalue::PyObjectHolder>();
141
+ }
142
+ inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::toPyObjectHolder()
143
+ const& {
144
+ TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
145
+ return toIntrusivePtr<ivalue::PyObjectHolder>();
146
+ }
147
+ inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() && {
148
+ TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
149
+ return moveToIntrusivePtr<ivalue::EnumHolder>();
150
+ }
151
+ inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const& {
152
+ TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
153
+ return toIntrusivePtr<ivalue::EnumHolder>();
154
+ }
155
+ inline c10::complex<double> IValue::toComplexDouble() const {
156
+ TORCH_INTERNAL_ASSERT(isComplexDouble(), "Expected ComplexDouble but got ", tagKind());
157
+ auto ptr = toIntrusivePtr<ivalue::ComplexHolder>();
158
+ return (*ptr).val;
159
+ }
160
+ inline at::Tensor IValue::toTensor() && {
161
+ if (C10_UNLIKELY(!isTensor())) {
162
+ reportToTensorTypeError();
163
+ }
164
+ auto result = std::move(payload.as_tensor);
165
+ // As far as I can tell, omitting the usual explicit destructor call
166
+ // is not UB in and of itself, and it's a slight perf win. The
167
+ // destructor is a no-op, because the moved-from Tensor is
168
+ // effectively an intrusive_ptr in the null state, so we don't need
169
+ // the behavior for correctness reasons either. Leaving this
170
+ // explanatory comment, including commented-out destructor call, to
171
+ // make this abundantly clear.
172
+ //
173
+ // payload.as_tensor.~Tensor();
174
+ clearToNone();
175
+ return result;
176
+ }
177
+ inline at::Tensor& IValue::toTensor() & {
178
+ if (C10_UNLIKELY(!isTensor())) {
179
+ reportToTensorTypeError();
180
+ }
181
+ return payload.as_tensor;
182
+ }
183
+ inline const at::Tensor& IValue::toTensor() const& {
184
+ if (C10_UNLIKELY(!isTensor())) {
185
+ reportToTensorTypeError();
186
+ }
187
+ return payload.as_tensor;
188
+ }
189
+ inline c10::Storage IValue::toStorage() && {
190
+ AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
191
+ return c10::Storage(
192
+ moveToIntrusivePtr<at::StorageImpl>());
193
+ }
194
+ inline c10::Storage IValue::toStorage() const& {
195
+ AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
196
+ return c10::Storage(toIntrusivePtr<at::StorageImpl>());
197
+ }
198
+ inline c10::Stream IValue::toStream() && {
199
+ AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
200
+ auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
201
+ return c10::Stream::unpack3((*ptr).val.stream_id,
202
+ (*ptr).val.device_index,
203
+ (*ptr).val.device_type);
204
+ }
205
+ inline c10::Stream IValue::toStream() const& {
206
+ AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
207
+ auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
208
+ return c10::Stream::unpack3((*ptr).val.stream_id,
209
+ (*ptr).val.device_index,
210
+ (*ptr).val.device_type);
211
+ }
212
+ inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() && {
213
+ AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
214
+ return moveToIntrusivePtr<caffe2::Blob>();
215
+ }
216
+ inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() const& {
217
+ AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
218
+ return toIntrusivePtr<caffe2::Blob>();
219
+ ;
220
+ }
221
+ inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() && {
222
+ TORCH_INTERNAL_ASSERT(isCapsule());
223
+ return moveToIntrusivePtr<torch::CustomClassHolder>();
224
+ }
225
+ inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() const& {
226
+ TORCH_INTERNAL_ASSERT(isCapsule());
227
+ return toIntrusivePtr<torch::CustomClassHolder>();
228
+ }
229
+ inline at::Generator IValue::toGenerator() && {
230
+ AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
231
+ return at::Generator(moveToIntrusivePtr<at::GeneratorImpl>());
232
+ }
233
+ inline at::Generator IValue::toGenerator() const& {
234
+ AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
235
+ return at::Generator(toIntrusivePtr<at::GeneratorImpl>());
236
+ }
237
+ inline c10::SymInt IValue::toSymInt() && {
238
+ AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
239
+ if (isSymInt()) {
240
+ return c10::SymInt(moveToIntrusivePtr<c10::SymNodeImpl>());
241
+ } else {
242
+ return c10::SymInt(payload.u.as_int);
243
+ }
244
+ }
245
+ inline c10::SymInt IValue::toSymInt() const& {
246
+ AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
247
+ if (isSymInt()) {
248
+ return c10::SymInt(toIntrusivePtr<c10::SymNodeImpl>());
249
+ } else {
250
+ return c10::SymInt(payload.u.as_int);
251
+ }
252
+ }
253
+ inline c10::SymFloat IValue::toSymFloat() && {
254
+ AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
255
+ if (isSymFloat()) {
256
+ return c10::SymFloat(moveToIntrusivePtr<c10::SymNodeImpl>());
257
+ } else {
258
+ return c10::SymFloat(payload.u.as_double);
259
+ }
260
+ }
261
+ inline c10::SymFloat IValue::toSymFloat() const& {
262
+ AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
263
+ if (isSymFloat()) {
264
+ return c10::SymFloat(toIntrusivePtr<c10::SymNodeImpl>());
265
+ } else {
266
+ return c10::SymFloat(payload.u.as_double);
267
+ }
268
+ }
269
+ inline c10::SymBool IValue::toSymBool() && {
270
+ AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind());
271
+ if (isSymBool()) {
272
+ return c10::SymBool(moveToIntrusivePtr<c10::SymNodeImpl>());
273
+ } else {
274
+ return c10::SymBool(payload.u.as_bool);
275
+ }
276
+ }
277
+
278
+ inline c10::SymBool IValue::toSymBool() const& {
279
+ AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind());
280
+ if (isSymBool()) {
281
+ return c10::SymBool(toIntrusivePtr<c10::SymNodeImpl>());
282
+ } else {
283
+ return c10::SymBool(payload.u.as_bool);
284
+ }
285
+ }
286
+
287
+ namespace ivalue {
288
+
289
+ void TORCH_API
290
+ checkCustomClassType(const ClassType* expected_type, const Type* actual_type);
291
+
292
+ template <typename T>
293
+ using Shared = c10::intrusive_ptr<T>;
294
+
295
+ // string
296
+ struct TORCH_API ConstantString final : c10::intrusive_ptr_target {
297
+ private:
298
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
299
+ const std::string str_;
300
+
301
+ public:
302
+ ConstantString(std::string str) : str_(std::move(str)) {}
303
+ ConstantString(c10::string_view str) : str_(std::string(str)) {}
304
+ static c10::intrusive_ptr<ConstantString> create(std::string str_);
305
+ static c10::intrusive_ptr<ConstantString> create(c10::string_view str_);
306
+ static c10::intrusive_ptr<ConstantString> create(const char* str_);
307
+
308
+ const std::string& string() const {
309
+ return str_;
310
+ }
311
+ c10::string_view string_view() const {
312
+ return str_;
313
+ }
314
+
315
+ operator const std::string&() const {
316
+ return string();
317
+ }
318
+ TORCH_API friend std::ostream& operator<<(
319
+ std::ostream& out,
320
+ const ConstantString& v);
321
+ };
322
+
323
+ struct Future;
324
+
325
+ struct TORCH_API TupleElements {
326
+ private:
327
+ size_t inlineSize_;
328
+ // We represent TupleElements this way to save doing a heap
329
+ // allocation in the common (at least for unpickling) case where we
330
+ // have only 3 elements. We have our own union instead of
331
+ // c10::SmallVector<IValue> because c10::SmallVector<IValue> always
332
+ // stores the begin/end/capacity pointers, which would be a waste of
333
+ // space in our use case.
334
+ union {
335
+ std::vector<IValue> elementsVector_;
336
+ // Don't want to declare a std::array because the convenient
337
+ // iteration and size members are a footgun in this case -- the
338
+ // actual size of the array may be smaller than 3!
339
+ // NOLINTNEXTLINE(*c-arrays*)
340
+ IValue elementsInline_[3];
341
+ };
342
+
343
+ void destroyInline() {
344
+ for (const auto ii : c10::irange(inlineSize_)) {
345
+ elementsInline_[ii].~IValue();
346
+ }
347
+ }
348
+ public:
349
+
350
+ using iterator = IValue*;
351
+ using const_iterator = const IValue*;
352
+
353
+ TupleElements() : inlineSize_(0) {
354
+ new (&elementsVector_) std::vector<IValue>();
355
+ }
356
+
357
+ explicit TupleElements(std::vector<IValue> elements)
358
+ : inlineSize_(0), elementsVector_(std::move(elements)) {}
359
+
360
+ explicit TupleElements(c10::ArrayRef<IValue> elements)
361
+ : inlineSize_(elements.size() <= 3 ? elements.size() : 0) {
362
+ switch (inlineSize_) {
363
+ case 3:
364
+ new (&elementsInline_[2]) IValue(elements[2]);
365
+ [[fallthrough]];
366
+ case 2:
367
+ new (&elementsInline_[1]) IValue(elements[1]);
368
+ [[fallthrough]];
369
+ case 1:
370
+ new (&elementsInline_[0]) IValue(elements[0]);
371
+ break;
372
+ case 0:
373
+ new (&elementsVector_) std::vector<IValue>(elements.begin(), elements.end());
374
+ break;
375
+ }
376
+ }
377
+
378
+ explicit TupleElements(IValue&& e1)
379
+ : inlineSize_(1) {
380
+ new (&elementsInline_[0]) IValue(std::move(e1));
381
+ }
382
+
383
+ explicit TupleElements(IValue&& e1, IValue&& e2)
384
+ : inlineSize_(2) {
385
+ new (&elementsInline_[0]) IValue(std::move(e1));
386
+ new (&elementsInline_[1]) IValue(std::move(e2));
387
+ }
388
+
389
+ explicit TupleElements(IValue&& e1, IValue&& e2, IValue&& e3)
390
+ : inlineSize_(3) {
391
+ new (&elementsInline_[0]) IValue(std::move(e1));
392
+ new (&elementsInline_[1]) IValue(std::move(e2));
393
+ new (&elementsInline_[2]) IValue(std::move(e3));
394
+ }
395
+
396
+ ~TupleElements() {
397
+ if (inlineSize_) {
398
+ destroyInline();
399
+ } else {
400
+ elementsVector_.~vector();
401
+ }
402
+ }
403
+
404
+ // It would be nice to make this noncopyable to prevent people from
405
+ // writing code like `auto output =
406
+ // forward(...).toTupleRef().elements()` (which does refcount bumps on
407
+ // each element, unlike the more efficient but verbose
408
+ // ```
409
+ // auto outputIntrusivePtr = forward(...).toTuple();
410
+ // const auto& output = outputIntrusivePtr->elements();
411
+ // ```
412
+ // ), but there is simply an overwhelming amount of code that does
413
+ // it the inefficient way.
414
+ // See also operator std::vector below.
415
+ TupleElements(const TupleElements& rhs)
416
+ : inlineSize_(rhs.inlineSize_) {
417
+ if (rhs.inlineSize_) {
418
+ for (const auto ii : c10::irange(inlineSize_)) {
419
+ new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
420
+ }
421
+ } else {
422
+ new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_);
423
+ }
424
+ }
425
+
426
+ TupleElements& operator=(const TupleElements& rhs) {
427
+ if (inlineSize_) {
428
+ if (rhs.inlineSize_) {
429
+ for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) {
430
+ elementsInline_[ii] = rhs.elementsInline_[ii];
431
+ }
432
+ if (rhs.inlineSize_ > inlineSize_) {
433
+ for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) {
434
+ new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
435
+ }
436
+ } else {
437
+ for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) {
438
+ elementsInline_[ii].~IValue();
439
+ }
440
+ }
441
+ } else {
442
+ destroyInline();
443
+ new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_);
444
+ }
445
+ } else {
446
+ if (rhs.inlineSize_) {
447
+ elementsVector_.~vector();
448
+ for (const auto ii : c10::irange(rhs.inlineSize_)) {
449
+ new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
450
+ }
451
+ } else {
452
+ elementsVector_ = rhs.elementsVector_;
453
+ }
454
+ }
455
+ inlineSize_ = rhs.inlineSize_;
456
+ return *this;
457
+ }
458
+
459
+ TupleElements(TupleElements&& rhs) noexcept
460
+ : inlineSize_(rhs.inlineSize_) {
461
+ if (inlineSize_) {
462
+ for (const auto ii : c10::irange(inlineSize_)) {
463
+ new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
464
+ }
465
+ } else {
466
+ new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_));
467
+ }
468
+ }
469
+
470
+ TupleElements& operator=(TupleElements&& rhs) noexcept {
471
+ if (inlineSize_) {
472
+ if (rhs.inlineSize_) {
473
+ for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) {
474
+ elementsInline_[ii] = std::move(rhs.elementsInline_[ii]);
475
+ }
476
+ if (rhs.inlineSize_ > inlineSize_) {
477
+ for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) {
478
+ new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
479
+ }
480
+ } else {
481
+ for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) {
482
+ elementsInline_[ii].~IValue();
483
+ }
484
+ }
485
+ } else {
486
+ destroyInline();
487
+ new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_));
488
+ }
489
+ } else {
490
+ if (rhs.inlineSize_) {
491
+ elementsVector_.~vector();
492
+ for (const auto ii : c10::irange(rhs.inlineSize_)) {
493
+ new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
494
+ }
495
+ } else {
496
+ elementsVector_ = std::move(rhs.elementsVector_);
497
+ }
498
+ }
499
+ inlineSize_ = rhs.inlineSize_;
500
+ return *this;
501
+ }
502
+
503
+ C10_NODISCARD c10::ArrayRef<IValue> asArrayRef() const {
504
+ if (inlineSize_) {
505
+ return c10::ArrayRef<IValue>(elementsInline_, inlineSize_);
506
+ } else {
507
+ return elementsVector_;
508
+ }
509
+ }
510
+
511
+ // Mimic implicit conversion from std::vector to ArrayRef.
512
+ operator c10::ArrayRef<IValue>() const {
513
+ return asArrayRef();
514
+ }
515
+
516
+ static size_t hash(const TupleElements& v) {
517
+ return c10::hash<c10::ArrayRef<IValue>>()(v.asArrayRef());
518
+ }
519
+
520
+ void setContents(std::vector<IValue>&& contents) {
521
+ if (inlineSize_) {
522
+ destroyInline();
523
+ new (&elementsVector_) std::vector<IValue>(std::move(contents));
524
+ inlineSize_ = 0;
525
+ } else {
526
+ elementsVector_ = std::move(contents);
527
+ }
528
+ }
529
+
530
+ C10_NODISCARD bool empty() const {
531
+ return inlineSize_ ? false : elementsVector_.empty();
532
+ }
533
+
534
+ C10_NODISCARD size_t size() const {
535
+ return inlineSize_ ? inlineSize_ : elementsVector_.size();
536
+ }
537
+
538
+ C10_NODISCARD IValue& operator[](size_t idx) {
539
+ if (inlineSize_) {
540
+ return elementsInline_[idx];
541
+ } else {
542
+ return elementsVector_[idx];
543
+ }
544
+ }
545
+
546
+ C10_NODISCARD const IValue& operator[](size_t idx) const {
547
+ if (inlineSize_) {
548
+ return elementsInline_[idx];
549
+ } else {
550
+ return elementsVector_[idx];
551
+ }
552
+ }
553
+
554
+ C10_NODISCARD IValue& at(size_t idx) {
555
+ if (inlineSize_) {
556
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3);
557
+ TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
558
+ return elementsInline_[idx];
559
+ } else {
560
+ return elementsVector_.at(idx);
561
+ }
562
+ }
563
+
564
+ C10_NODISCARD const IValue& at(size_t idx) const {
565
+ if (inlineSize_) {
566
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3);
567
+ TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
568
+ return elementsInline_[idx];
569
+ } else {
570
+ TORCH_CHECK(idx < elementsVector_.size(), "TupleElements: invalid index Index = ", idx, "; Length = ", elementsVector_.size());
571
+ return elementsVector_.at(idx);
572
+ }
573
+ }
574
+
575
+ C10_NODISCARD iterator begin() {
576
+ if (inlineSize_) {
577
+ return elementsInline_;
578
+ } else {
579
+ return elementsVector_.data();
580
+ }
581
+ }
582
+
583
+ C10_NODISCARD iterator end() {
584
+ if (inlineSize_) {
585
+ return elementsInline_ + inlineSize_;
586
+ } else {
587
+ return elementsVector_.data() + elementsVector_.size();
588
+ }
589
+ }
590
+
591
+ C10_NODISCARD const_iterator begin() const {
592
+ if (inlineSize_) {
593
+ return elementsInline_;
594
+ } else {
595
+ return elementsVector_.data();
596
+ }
597
+ }
598
+
599
+ C10_NODISCARD const_iterator end() const {
600
+ if (inlineSize_) {
601
+ return elementsInline_ + inlineSize_;
602
+ } else {
603
+ return elementsVector_.data() + elementsVector_.size();
604
+ }
605
+ }
606
+
607
+ C10_NODISCARD const_iterator cbegin() const {
608
+ return begin();
609
+ }
610
+
611
+ C10_NODISCARD const_iterator cend() const {
612
+ return end();
613
+ }
614
+
615
+ C10_NODISCARD std::vector<IValue> vec() const & {
616
+ return asArrayRef().vec();
617
+ }
618
+
619
+ C10_NODISCARD IValue& back() {
620
+ return *(end() - 1);
621
+ }
622
+
623
+ C10_NODISCARD const IValue& back() const {
624
+ return *(end() - 1);
625
+ }
626
+
627
+ C10_NODISCARD std::vector<IValue> vec() && {
628
+ std::vector<IValue> result;
629
+ result.reserve(size());
630
+ for (auto&& iv : *this) {
631
+ result.push_back(std::move(iv));
632
+ }
633
+ return result;
634
+ }
635
+
636
+ // More compatibility shims for the overwhelming amount of code that
637
+ // likes to copy tuple elements into a vector; see comment above the
638
+ // copy constructor.
639
+ operator std::vector<IValue>() const & {
640
+ return vec();
641
+ }
642
+
643
+ operator std::vector<IValue>() && {
644
+ return vec();
645
+ }
646
+ };
647
+
648
// Dispatch helper mapping a type-system tag (static TupleType vs.
// c10::DynamicType) to the factory used to build or recover a tuple's type.
// The primary template is intentionally empty; only the two specializations
// below are usable.
template <typename T>
struct TupleTypeFactory {};

template <>
struct TORCH_API TupleTypeFactory<TupleType> {
  // Build a TupleType directly from the element types.
  static TupleTypePtr create(std::vector<TypePtr> types) {
    return TupleType::create(std::move(types));
  }
  // Used when a cached type cannot be cast to TupleType; defined out of line.
  static TupleTypePtr fallback(const Type& type);
};

template <>
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
  // Both members are defined out of line (DynamicType lives elsewhere).
  static DynamicTypePtr create(const std::vector<TypePtr>& elemTypes);
  static DynamicTypePtr fallback(const Type&);
};
664
+
665
// A refcounted tuple of IValues. All construction goes through the static
// create()/createNamed() factories; the constructors themselves are private
// (see the friend declaration at the bottom). The element storage is a
// TupleElements, which inlines up to three elements. The type is computed
// lazily for unnamed tuples and cached in type_.
struct TORCH_API Tuple : c10::intrusive_ptr_target {
 private:
  TupleElements elements_;
  mutable c10::TypePtr type_; // lazily computed for unnamed tuples

 public:
  // named tuples have additional type information, so we
  // directly create them tagged
  static c10::intrusive_ptr<Tuple> createNamed(
      std::vector<IValue> elements_,
      c10::TypePtr type_) {
    return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_));
  }

  static c10::intrusive_ptr<Tuple> createNamed(
      TupleElements elements_,
      std::shared_ptr<TupleType> type_) {
    return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_));
  }

  static c10::intrusive_ptr<Tuple> createNamed(
      std::initializer_list<IValue> elements_,
      std::shared_ptr<TupleType> type_) {
    return createNamed(TupleElements(c10::ArrayRef<IValue>(elements_)), std::move(type_));
  }

  // MSVC apparently can't disambiguate the other two overloads of
  // create when passed an initializer_list without this.
  static c10::intrusive_ptr<Tuple> create(std::initializer_list<IValue> elements_) {
    return create(c10::ArrayRef<IValue>(elements_));
  }

  static c10::intrusive_ptr<Tuple> create(std::vector<IValue> elements_) {
    return c10::make_intrusive<Tuple>(std::move(elements_));
  }

  static c10::intrusive_ptr<Tuple> create(TupleElements elements_) {
    return c10::make_intrusive<Tuple>(std::move(elements_));
  }

  static c10::intrusive_ptr<Tuple> create(c10::ArrayRef<IValue> elements_) {
    return create(TupleElements(elements_));
  }

  // Arity-specific factories that hit TupleElements' inline (non-allocating)
  // storage for up to three elements.
  static c10::intrusive_ptr<Tuple> create(IValue e1) {
    return c10::make_intrusive<Tuple>(std::move(e1));
  }

  static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2) {
    return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2));
  }

  static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2, IValue e3) {
    return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2), std::move(e3));
  }

 private:
  // Workaround inability to use `>` operator in template argument list.
  // NOTE(review): not referenced by any member of this struct — appears
  // vestigial; confirm against callers before removing.
  template <typename... Args>
  static constexpr bool hasMoreThanThreeArgs() {
    return sizeof...(Args) > 3;
  }

 public:
  // Generic variadic factory: 1-3 arguments route to the inline-storage
  // overloads above; any other arity falls back to a std::vector. The switch
  // is on a compile-time constant, so each instantiation takes one branch.
  template <typename... Args>
  static c10::intrusive_ptr<Tuple> create(Args&&... elements_) {
    switch (sizeof...(Args)) {
      case 1:
      case 2:
      case 3:
        return create(IValue(std::forward<Args>(elements_))...);
      default:
        return create(
            std::vector<IValue>{IValue(std::forward<Args>(elements_))...});
    }
  }

  // Again, it would be nice to make this noncopyable, but there's a
  // lot of extant code that copies Tuples.
  // Tuple(const Tuple& rhs) = delete;

  const TupleElements& elements() const& {
    return elements_;
  }

  // Rvalue overload: steals the elements from a temporary Tuple.
  TupleElements elements() && {
    return std::move(elements_);
  }

  void setElements(std::vector<IValue>&& elements) {
    elements_.setContents(std::move(elements));
  }

  void setElements(TupleElements&& elements) {
    elements_ = std::move(elements);
  }

  // "unsafe": no bounds checking (uses TupleElements::operator[]), and the
  // lazily cached type_ is NOT invalidated by this write.
  void unsafeSetElement(size_t idx, const IValue& element) {
    elements_[idx] = element;
  }

  void unsafeSetElement(size_t idx, IValue&& element) {
    elements_[idx] = std::move(element);
  }

  size_t size() const {
    return elements_.size();
  }

  // Returns this tuple's type, computing it from the element types and
  // caching it in type_ on first use. If the cached type cannot be cast to
  // T, delegates to TupleTypeFactory<T>::fallback.
  template <typename T = c10::TupleType>
  std::shared_ptr<T> type() const {
    if (!type_) {
      type_ = TupleTypeFactory<T>::create(fmap(elements(), [&](const IValue& v) {
        return v.type<typename T::ElementType>();
      }));
    }
    if (auto t = type_->cast<T>()) {
      return t;
    }
    return TupleTypeFactory<T>::fallback(*type_);
  }

  // Hash is over the elements only; type_ does not participate.
  static size_t hash(const Tuple& t) {
    return c10::get_hash(t.elements());
  }

  TORCH_API friend bool operator==(
      const ivalue::Tuple& lhs,
      const ivalue::Tuple& rhs);

 private:
  // NOTE: If we try to avoid the overloads without
  // `std::shared_ptr<TupleType> type` by defaulting it to nullptr, we
  // end up having to call (part of) the shared_ptr destructor for
  // `type` even though we should know statically it won't do
  // anything.
  explicit Tuple(std::vector<IValue> elements)
      : elements_(std::move(elements)){}

  explicit Tuple(std::vector<IValue> elements, c10::TypePtr type)
      : elements_(std::move(elements)), type_(std::move(type)) {}

  explicit Tuple(TupleElements&& elements)
      : elements_(std::move(elements)) {}

  explicit Tuple(TupleElements&& elements, std::shared_ptr<TupleType> type)
      : elements_(std::move(elements)), type_(std::move(type)) {}

  explicit Tuple(IValue&& e1)
      : elements_(std::move(e1)) {}

  explicit Tuple(IValue&& e1, std::shared_ptr<TupleType> type)
      : elements_(std::move(e1)), type_(std::move(type)) {}

  explicit Tuple(IValue&& e1, IValue&& e2)
      : elements_(std::move(e1), std::move(e2)) {}

  explicit Tuple(IValue&& e1, IValue&& e2, std::shared_ptr<TupleType> type)
      : elements_(std::move(e1), std::move(e2)), type_(std::move(type)) {}

  explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3)
      : elements_(std::move(e1), std::move(e2), std::move(e3)) {}

  explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3, std::shared_ptr<TupleType> type)
      : elements_(std::move(e1), std::move(e2), std::move(e3)), type_(std::move(type)) {}

  // Grants intrusive_ptr machinery access to the private constructors.
  friend class c10::intrusive_ptr<Tuple>;
};
833
+
834
+ struct Object;
835
+ struct PyObjectHolder;
836
+ struct EnumHolder;
837
+ } // namespace ivalue
838
+
839
+ // Future
840
// A one-shot asynchronous result holder: completes exactly once, either with
// an IValue (markCompleted) or an error (setError), and then runs any
// registered callbacks. When constructed with a non-empty device list it also
// tracks the storages/events needed to synchronize device streams with the
// value (via the generic VirtualGuardImpl, so there is no direct CUDA
// dependency here).
struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
 private:
  // Keep this private in order to force users to go through make_intrusive and
  // thus prevent creating a Future that's not held by an intrusive_ptr.
  explicit Future(TypePtr type, std::vector<c10::Device> devices={})
      : type_(std::move(type)),
        impl_(getTypeOfDevices(devices)),
        devices_(sortAndDeduplicateDevices(impl_, std::move(devices))) {}

  friend c10::intrusive_ptr<Future>;

  struct FutureCallback {
    std::function<void(Future&)> callback;
    bool uses_future; // whether the Future& passed in is actually used

    template <typename T>
    FutureCallback(T callback, bool uses_future)
        : callback(std::move(callback)), uses_future(uses_future) {}
  };

 public:
  Future(const Future&) = delete;
  Future(Future&&) = delete;
  Future& operator=(const Future&) = delete;
  Future& operator=(Future&&) = delete;

  struct TORCH_API FutureError final : public std::exception {
    explicit FutureError(std::string&& error_msg_)
        : error_msg(std::move(error_msg_)) {}

    FutureError() = default;

    const char* what() const noexcept override {
      return error_msg.c_str();
    }

    std::string error_msg;
  };

  /**
   * Wait on the future until it completes.
   */
  void wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    finished_cv_.wait(lock, [&]() -> bool { return completed_; });
    synchronizeWithCurrentStreams();
  }

  /**
   * Wait on the future until it completes and throw an
   * exception if an error exists.
   */
  void waitAndThrow() {
    wait();

    if (eptr_) {
      std::rethrow_exception(eptr_);
    }
  }

  /**
   * Explicitly mark the future as completed with the output value. Optionally,
   * the storages for all tensors in IValue can be passed as well. The DataPtrs
   * of these storages are used to synchronize CUDA streams. If storages isn't
   * given we will attempt to extract it from the value, if we need to (this
   * happens if a non-empty set of devices was given to the constructor). Thus
   * one only needs to provide storages when 1) they cannot be extracted through
   * IValue::getSubValues() or through pickling in case of Python object; or
   * when 2) customized storage extraction is more efficient.
   */
  using WeakStorage = c10::weak_intrusive_ptr<c10::StorageImpl>;
  void markCompleted(
      IValue value,
      std::optional<std::vector<WeakStorage>> storages = std::nullopt) {
    // Start by performing all steps that can throw, before setting any field.
    // Do this before even acquiring the mutex, because extractStorages might
    // acquire the GIL, which could lead to a lock inversion with our mutex.
    // See https://github.com/pytorch/pytorch/issues/58239.
    std::vector<WeakStorage> actualStorages;
    std::vector<c10::Device> usedDevices;
    try {
      // FIXME We should always extract DataPtrs, in order to catch the case of
      // users using CUDA values but forgetting to set devices, which currently
      // leads to a silent synchronization/correctness issue. However, as this
      // might worsen perf in CPU-only cases, we should only do so after careful
      // benchmarks.
      if (impl_.type() != c10::kCPU) {
        actualStorages =
            storages.has_value() ? std::move(*storages) : extractStorages(value);
        usedDevices = getDevicesOfStorages(impl_, actualStorages);
        ensureIsSubsetOfDevices(usedDevices, devices_);
      }
    } catch (const std::exception&) {
      // A failure in storage extraction/validation completes the future with
      // that error instead of the value.
      setError(std::current_exception());
      return;
    }

    std::unique_lock<std::mutex> lock(mutex_);
    TORCH_CHECK(
        !completed(),
        "Attempting to mark a completed Future as complete again. Note that "
        "a Future can only be marked completed once.");

    // Only set value_ and completed_ flag once all checks and preparation steps
    // have returned successfully to allow for proper error propagation.
    value_ = std::move(value);
    completed_ = true;

    currentDevice_ = impl_.getDevice();
    storages_ = std::move(actualStorages);
    for (const c10::Device& device : usedDevices) {
      c10::Event event(impl_.type());
      event.record(impl_.getStream(device));
      events_.push_back(std::move(event));
    }

    // Swap callbacks out and drop the lock BEFORE invoking them, so callbacks
    // may safely re-enter this Future's API.
    std::vector<FutureCallback> cbs;
    cbs.swap(callbacks_);
    lock.unlock();

    finished_cv_.notify_all();
    for (auto& callback : cbs) {
      invokeCallback(std::move(callback.callback), callback.uses_future);
    }
  }

  // Convenience overload: complete with a none/empty IValue.
  void markCompleted() {
    markCompleted(IValue{});
  }

  // Complete the future with an error. Must not already be completed (checked
  // inside setErrorInternal).
  void setError(std::exception_ptr eptr) {
    std::unique_lock<std::mutex> lock(mutex_);
    setErrorInternal(std::move(eptr), lock);
  }

  // Like setError(), but if the future is already completed the error is
  // logged and dropped instead of triggering the already-completed check.
  void setErrorIfNeeded(std::exception_ptr eptr) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (completed_) {
      // This should be rare and shouldn't cause log spew. Its important to
      // log errors and thats why we have this log here.
      std::string msg = c10::str(
          "Skipping setting following error on the Future since "
          "it is already marked completed (this is not necessarily "
          "an error):\n",
          tryRetrieveErrorMessageInternal(std::move(eptr)));
      if (eptr_) {
        msg += c10::str(
            ", \nOriginal exception:\n",
            tryRetrieveErrorMessageInternal(eptr_));
      }
      LOG(INFO) << msg;
      return;
    } else {
      setErrorInternal(std::move(eptr), lock);
    }
  }

  // Get the result of the current future.
  // Requires completed(); rethrows if the future holds an error. Returns a
  // copy of the stored IValue.
  IValue value() {
    std::unique_lock<std::mutex> lock(mutex_);
    AT_ASSERT(completed());
    if (eptr_) {
      std::rethrow_exception(eptr_);
    }
    return value_;
  }

  // This accessor should only be used if we know that the future is
  // completed() with no error.
  const IValue& constValue() const {
    std::unique_lock<std::mutex> lock(mutex_);
    AT_ASSERT(completed());
    TORCH_INTERNAL_ASSERT(
      !eptr_,
      "value() accessor should only be used when future is not completed with ",
      "an error, but future had the following error: ",
      tryRetrieveErrorMessageInternal(eptr_)
    );
    return value_;
  }

  // This accessor should only be used if we know that the future is
  // completed() with no error.
  const std::vector<WeakStorage>& storages() const {
    std::unique_lock<std::mutex> lock(mutex_);
    AT_ASSERT(completed());
    AT_ASSERT(!eptr_);
    return storages_;
  }

  /**
   * Add a callback to the future.
   * The callbacks will be executed once the future completes.
   * If the future has already completed,
   * this function will execute the callback immediately.
   */
  template <typename T>
  void addCallback(T callback, bool uses_future = true) {
    static_assert(
        std::is_invocable_r<void, T, Future&>::value,
        "The callback must have signature void(Future&)");

    std::unique_lock<std::mutex> lock(mutex_);
    if (completed()) {
      // Run outside the lock so the callback can re-enter this Future.
      lock.unlock();
      invokeCallback(std::move(callback), uses_future);
      return;
    }
    callbacks_.emplace_back(std::move(callback), uses_future);
  }

  /**
   * Add a callback to the future, and return another Future to hold the return
   * value of the callback. This is necessary when the callback provider needs
   * to know for sure when the callback has finished.
   */
  template <typename T>
  c10::intrusive_ptr<Future> then(T callback, TypePtr type) {
    using IValueWithStorages = std::tuple<IValue, std::vector<WeakStorage>>;
    static_assert(
        std::disjunction<
            std::is_invocable_r<IValue, T, Future&>,
            std::is_invocable_r<IValueWithStorages, T, Future&>>::value,
        "The callback must have signature IValue(Future&) or "
        "std::tuple<IValue, std::vector<Storage>>(Future&)");

    auto childFut = createInstance(::std::move(type));
    addCallback([childFut,
                 cb = std::move(callback)](Future& parentFut) mutable {
      try {
        if constexpr (::std::is_convertible_v<typename std::invoke_result_t<T &&, Future&>, IValueWithStorages>) {
          auto [ivalue, storages] = cb(parentFut);
          childFut->markCompleted(::std::move(ivalue), ::std::move(storages));
        } else {
          childFut->markCompleted(cb(parentFut));
        }
      } catch (std::exception&) {
        // Propagate the callback's failure to the child future.
        childFut->setError(std::current_exception());
      }
    });
    return childFut;
  }

  // Like then(), but the callback itself returns a Future; the returned child
  // future completes (or errors) when that intermediate future does.
  template <typename T>
  c10::intrusive_ptr<Future> thenAsync(T callback, TypePtr type) {
    static_assert(
        std::is_invocable_r<c10::intrusive_ptr<Future>, T, Future&>::value,
        "The callback must have signature c10::intrusive_ptr<Future>(Future&)");

    auto childFut = createInstance(std::move(type));
    addCallback(
        [childFut, cb = std::move(callback)](Future& parentFut) mutable {
          c10::intrusive_ptr<Future> intermediateFut;
          try {
            intermediateFut = cb(parentFut);
          } catch (std::exception&) {
            childFut->setError(std::current_exception());
            return;
          }
          intermediateFut->addCallback(
              [childFut = std::move(childFut)](Future& intermediateFut) {
                if (intermediateFut.hasError()) {
                  childFut->setError(intermediateFut.exception_ptr());
                } else {
                  childFut->markCompleted(
                      intermediateFut.value(), intermediateFut.storages());
                }
              });
        });
    return childFut;
  }

  // Tries to retrieve the error message from std::exception_ptr.
  std::string tryRetrieveErrorMessage() const {
    TORCH_CHECK(hasError(), "No error present on the future.");
    std::unique_lock<std::mutex> lock(mutex_);
    return tryRetrieveErrorMessageInternal(eptr_);
  }

  // Check if the current future has completed
  // (lock-free read of the atomic completed_ flag).
  bool completed() const {
    return completed_;
  }

  // True iff completed successfully (no error recorded).
  bool hasValue() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return completed_ && !eptr_;
  }

  bool hasError() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return eptr_ ? true : false;
  }

  std::exception_ptr exception_ptr() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return eptr_;
  }

  TORCH_API friend std::ostream& operator<<(
      std::ostream& out,
      const Future& v);

  // Type of the value this future will hold.
  const TypePtr& elementType() const {
    return type_;
  }

  // Sorted, deduplicated bounding set of devices (set at construction).
  const std::vector<c10::Device>& devices() const {
    return devices_;
  }

  // This method should be used when one intends to manually create a child
  // future, for example when implementing a customized version of then().
  c10::intrusive_ptr<Future> createInstance(at::TypePtr type) {
    return c10::make_intrusive<Future>(std::move(type), devices_);
  }

 private:

  // This method should always be used when invoking a callback (regardless of
  // how/when that happens) as it will ensure that the proper "environment" is
  // set up before running the callback, as in, it will set up the CUDA streams,
  // synchronize them with the value, and so on (if needed).
  template<typename T>
  void invokeCallback(T callback, bool uses_future) {
    static_assert(
        std::is_invocable_r<void, T, Future&>::value,
        "The callback must have signature void(Future&)");

    // The synchronization performed below shouldn't be needed when the future
    // is not used by the callback.
    if (uses_future) {
      c10::OptionalDeviceGuard deviceGuard(currentDevice_);

      std::vector<c10::Stream> streams;
      streams.reserve(devices_.size());
      for (const c10::Device& device : devices_) {
        streams.push_back(impl_.getStreamFromGlobalPool(device));
      }
      c10::MultiStreamGuard streamGuard(streams);
      synchronizeWithCurrentStreams();
      callback(*this);
    } else {
      callback(*this);
    }
  }

  // This method should be called before this future's value is used, as it
  // ensures that the CUDA streams that are "current" at the callsite properly
  // synchronize with the value.
  void synchronizeWithCurrentStreams() {
    for (c10::Event& event : events_) {
      event.block(impl_.getStream(event.device()));
    }

    for (const WeakStorage& weak_storage : storages_) {
      c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock();
      if (!storage) {
        // Storage has since been freed; nothing to record.
        continue;
      }
      if (!storage->device().is_cpu()) {
        impl_.recordDataPtrOnStream(
            storage->data_ptr(), impl_.getStream(storage->device()));
      }
    }
  }

  // Completes the future with an error. The caller holds `lock`, which is
  // released here before the stored callbacks are invoked.
  void setErrorInternal(
      std::exception_ptr eptr,
      std::unique_lock<std::mutex>& lock) {
    TORCH_CHECK(
        !eptr_,
        "Error already set on this Future: ",
        tryRetrieveErrorMessageInternal(eptr_),
        ", trying to set error: ",
        tryRetrieveErrorMessageInternal(eptr));
    TORCH_INTERNAL_ASSERT(!completed(), "Future is already marked completed");
    completed_ = true;
    eptr_ = std::move(eptr);

    std::vector<FutureCallback> cbs;
    cbs.swap(callbacks_);
    lock.unlock();

    finished_cv_.notify_all();
    for (auto& callback : cbs) {
      invokeCallback(std::move(callback.callback), callback.uses_future);
    }
  }

  // Tries to retrieve the error message from std::exception_ptr.
  std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const {
    try {
      std::rethrow_exception(std::move(eptr));
    } catch (const std::exception& e) {
      return e.what();
    } catch (...) {
      return "Unknown Exception Type";
    }
  }

  // Defined in ivalue.cpp.
  static std::vector<WeakStorage> extractStorages(
      const at::IValue& value);

  // Returns the (index-sorted) distinct non-CPU devices on which the given
  // storages live; checks each one matches impl's device type.
  static std::vector<c10::Device> getDevicesOfStorages(
      const c10::impl::VirtualGuardImpl& impl,
      const std::vector<WeakStorage>& storages) {
    c10::DeviceIndex deviceCount = impl.deviceCount();
    std::vector<bool> isDeviceUsed(deviceCount, false);
    for (const WeakStorage& weak_storage : storages) {
      c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock();
      if (!storage) {
        continue;
      }
      c10::Device device = storage->device();
      if (!device.is_cpu()) {
        TORCH_CHECK_VALUE(
            device.type() == impl.type(),
            "Expected all data ptrs to be on a device of type ",
            impl.type(),
            ", got one on device ",
            device);
        isDeviceUsed[device.index()] = true;
      }
    }
    std::vector<c10::Device> devices;
    for (c10::DeviceIndex idx = 0; idx < deviceCount; idx++) {
      if (isDeviceUsed[idx]) {
        devices.emplace_back(impl.type(), idx);
      }
    }
    return devices;
  }

  // Formats a device list for error messages, e.g. "dev0, dev1 and dev2".
  static std::string formatSetOfDevices(
      const std::vector<c10::Device>& devices) {
    if (devices.empty()) {
      return "(none)";
    }
    std::ostringstream oss;
    oss << devices[0];
    for (const auto idx : c10::irange(1, devices.size())) {
      if (idx == devices.size() - 1) {
        oss << " and ";
      } else {
        oss << ", ";
      }
      oss << devices[idx];
    }
    return oss.str();
  }

  // Returns the common device type of `devices` (kCPU when empty); throws if
  // they are not all of the same type.
  static c10::DeviceType getTypeOfDevices(
      const std::vector<c10::Device>& devices) {
    if (devices.empty()) {
      return c10::kCPU;
    }
    c10::DeviceType deviceType = devices[0].type();
    for (const auto idx : c10::irange(1, devices.size())) {
      TORCH_CHECK_VALUE(
          devices[idx].type() == deviceType,
          "Expected all devices to be of the same type, but got a mismatch between ",
          devices[0],
          " and ",
          devices[idx]);
    }
    return deviceType;
  }

  // We need devices to be sorted in order to use ensureIsSubsetOfDevices.
  static std::vector<c10::Device> sortAndDeduplicateDevices(
      const c10::impl::VirtualGuardImpl& /*impl*/,
      std::vector<c10::Device> devices) {
    std::sort(
      devices.begin(), devices.end(),
      [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); });
    // Deduplicate by compacting.
    size_t targetIdx = 0;
    for (const auto sourceIdx : c10::irange(devices.size())) {
      TORCH_CHECK_VALUE(
          devices[sourceIdx].has_index(),
          "Expected devices to have indices, got ", devices[sourceIdx]);
      if (targetIdx > 0 && devices[targetIdx - 1].index() == devices[sourceIdx].index()) {
        // It's a duplicate, skip it.
        continue;
      }
      if (sourceIdx != targetIdx) {
        devices[targetIdx] = devices[sourceIdx];
      }
      targetIdx++;
    }
    // If there were duplicates there's now a gap at the end: trim it. Resizing
    // requires the item type to be default-constructible (which c10::Device is
    // not) because in principle it could be required to create new items. Since
    // we know we'll shrink the vector, we provide a custom dummy value instead.
    devices.resize(targetIdx, c10::Device(c10::kCPU));
    return devices;
  }

  // Throws unless every device in `subset` also appears in `superset`
  // (comparison is by index; both inputs must be sorted/deduplicated).
  static void ensureIsSubsetOfDevices(
      const std::vector<c10::Device>& subset,
      const std::vector<c10::Device>& superset) {
    // We assume the devices in both vectors have the same consistent type, and
    // their indices are unique and sorted.
    std::vector<c10::Device> excessDevices;
    std::set_difference(
        subset.begin(),
        subset.end(),
        superset.begin(),
        superset.end(),
        std::back_inserter(excessDevices),
        [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); });
    TORCH_CHECK_VALUE(
        excessDevices.empty(),
        "The result contained tensors residing on device(s) ",
        formatSetOfDevices(excessDevices),
        " which are not among the expected device(s) ",
        formatSetOfDevices(superset));
  }

  // Guards all non-atomic state below.
  mutable std::mutex mutex_;
  std::atomic_bool completed_ = {false}; // is this future complete
  std::condition_variable finished_cv_;

  IValue value_; // when finished the value
  TypePtr type_;
  std::vector<FutureCallback> callbacks_;
  std::exception_ptr eptr_;

  // An upcast pointer to a virtual class which allows us to manipulate events,
  // streams, ... in a generic way, without an explicit dependency on CUDA.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const c10::impl::VirtualGuardImpl impl_;

  // The device that was current when markCompleted was called, which we'll
  // restore when invoking callbacks. It's optional because we'll only store it
  // if the future completes successfully.
  std::optional<c10::Device> currentDevice_;

  // The events that correspond to the completion of the async I/O kernels. They
  // are recorded on the appropriate streams when the future is marked completed
  // and can then be queried/waited/blocked on. There is one event for each
  // distinct device on which the value's tensors reside.
  std::vector<c10::Event> events_;

  // A cached version of the storages extracted from the value when the future
  // is first marked completed.
  std::vector<WeakStorage> storages_;

  // The bounding set of devices that this future, and any of its children, is
  // allowed to use. This is a superset of the set of devices used by the events
  // above. We need this to know what streams (for which devices) to set as
  // current when invoking a callback, thus allowing the callback to use devices
  // that the parent future didn't use. This field is set to the value provided
  // in the constructor and will be "inherited" by all child futures.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const std::vector<c10::Device> devices_;
};
1399
+
1400
// Await is a lazily-evaluated value holder: the stored thunk `fn_` runs at
// most once, the first time wait() is called, and its result is cached in
// `value_`. Note there is no mutex or condition variable here (unlike
// ivalue::Future), so an Await must not be completed/waited on concurrently
// from multiple threads without external synchronization.
struct C10_EXPORT ivalue::Await final : c10::intrusive_ptr_target {
 private:
  // Constructors are private; instances are created via c10::make_intrusive
  // (hence the friend declaration below).
  explicit Await(TypePtr elType, std::function<IValue()> fn)
      : elType_(std::move(elType)), type_(AwaitType::create(elType_)), fn_(std::move(fn)) {}

  explicit Await(TypePtr elType) : elType_(std::move(elType)), type_(AwaitType::create(elType_)) { }

  friend c10::intrusive_ptr<Await>;

 public:
  Await(const Await&) = delete;
  Await(Await&&) = delete;
  Await& operator=(const Await&) = delete;
  Await& operator=(Await&&) = delete;

  // Returns the value, computing it via `fn_` on first call and caching it.
  // The stored args are released after completion: they only exist to keep
  // `fn_`'s inputs alive until it has run.
  IValue wait() {
    if (!completed_) {
      TORCH_CHECK(fn_, "Incompleted Await: fn can't be None");
      value_ = fn_();
      completed_ = true;
      args_ = {};
    }
    return value_;
  }

  // Returns the already-computed value; errors if neither wait() nor
  // markCompleted() has run yet.
  IValue value() {
    TORCH_CHECK(completed_, "Await must be completed");
    return value_;
  }

  void setFn(std::function<IValue()> fn) {
    fn_ = std::move(fn);
  }

  bool completed() {
    return completed_;
  }

  // Installs a value directly, bypassing `fn_`.
  void markCompleted(IValue value) {
    value_ = std::move(value);
    completed_ = true;
  }

  TORCH_API friend std::ostream& operator<<(
      std::ostream& out,
      const Await& v);

  // Type of the wrapped value.
  const TypePtr& elementType() const {
    return elType_;
  }

  // Type of this Await object itself, i.e. Await[elType].
  const TypePtr& type() const {
    return type_;
  }

  // Arguments kept alive for `fn_`; cleared by wait() after completion.
  void setArgs(std::vector<IValue> args) {
    args_ = std::move(args);
  }

  std::vector<IValue>& args() {
    return args_;
  }

 private:
  TypePtr elType_;
  TypePtr type_;
  std::vector<IValue> args_;
  std::function<IValue()> fn_;
  IValue value_;
  bool completed_{};
};
1471
+
1472
// Input is a list of Futures with the same target type.
// Output is a Future that completes once all inputs have completed; its
// value is the List of (now completed) input Futures.
TORCH_API intrusive_ptr<ivalue::Future> collectAll(
    const c10::List<c10::intrusive_ptr<ivalue::Future>>& srcs);
// Input is a List of Futures with the same target type.
// Output is a Future that will be updated with a seen value, i.e. it
// completes as soon as any of the inputs does.
TORCH_API intrusive_ptr<ivalue::Future> collectAny(
    const c10::List<c10::intrusive_ptr<ivalue::Future>>& srcs);
1480
+
1481
// User-defined object: an instance of a TorchScript class. Attribute values
// live in a flat slot vector indexed in parallel with the class type's
// attribute list.
struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
 public:
  // In general, class types hold a shared_ptr to its owning CompilationUnit,
  // so that its type and methods do not get deallocated while the class exists.
  // However, the CompilationUnit holds ownership of the type's graphs, so
  // inserting a constant object into a Graph would create a reference cycle if
  // that constant object held a shared_ptr to its CU. For these objects we
  // instantiate them with non-owning references to its CU.
  Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
    slots_.resize(numSlots);
  }

  Object(StrongTypePtr type, size_t numSlots)
      : type_(WeakOrStrongTypePtr(std::move(type))) {
    slots_.resize(numSlots);
  }

  static c10::intrusive_ptr<Object> create(
      WeakOrStrongTypePtr type,
      size_t numSlots) {
    return c10::make_intrusive<Object>(std::move(type), numSlots);
  }

  static c10::intrusive_ptr<Object> create(
      StrongTypePtr type,
      size_t numSlots) {
    return c10::make_intrusive<Object>(std::move(type), numSlots);
  }

  static c10::intrusive_ptr<Object> create(ClassTypePtr classType, size_t numSlots);

  /**
   * Slot API.
   *
   * Attributes are stored as a simple vector so that lookups are fast at
   * runtime. A "slot" is just an index into that vector, which can be computed
   * statically if you have access to the class type. Use this API if you are
   * writing compiler stuff.
   */
  void setSlot(size_t slot, IValue v) {
    if (slot >= slots_.size()) {
      // for module types, it is possible that the members of the class have
      // expanded after the object was created. In this case, we expand
      // the slots to the right size
      resizeObject(slot);
    }
    slots_[slot] = std::move(v);
  }

  const IValue& getSlot(size_t slot) const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(slot < slots_.size());
    // NOTE: This lookup is fairly hot, so we use unchecked access to the
    // vector.  Errors should still be detectable with ASan.
    return slots_[slot];
  }

  // Erases the slot, shifting later slots down. "Unsafe" because the class
  // type is not adjusted to match (see unsafeRemoveAttr below).
  void unsafeRemoveSlot(size_t slot) {
    TORCH_CHECK(slot < slots_.size());
    slots_.erase(slots_.begin() + static_cast<std::ptrdiff_t>(slot));
  }

  /**
   * Attribute API.
   *
   * Wrappers around the slot stuff so that users can access attributes
   * directly. Use this API if you are a user.
   *
   * Note: Unlike in Python, TorchScript must make a distinction between
   * attributes (which are IValues) and methods (which are Methods). If you
   * want a method, use `obj.type()->getMethod()`
   */
  IValue getAttr(const std::string& name) const;
  void setAttr(const std::string& name, IValue v);
  // Remove attribute by name, caller is responsible for
  // the safety of this operation.
  // We didn't remove the attribute in the type because the type
  // might be shared by multiple objects.
  // Therefore after removing an attribute, the object is in an inconsistent
  // state where its Type lists more attributes than the object has slots;
  // the user needs to restore consistency by removing the attribute from the
  // type as well.
  void unsafeRemoveAttr(const std::string& name);

  std::string name() const;

  const std::vector<IValue>& slots() const {
    return slots_;
  }
  std::shared_ptr<ClassType> type() const;

  // Returns the owning CompilationUnit, locking the weak reference if this
  // object only holds a non-owning one (throws if unset).
  std::shared_ptr<torch::jit::CompilationUnit> compilation_unit() {
    if (type_.holds_strong_ref()) {
      return type_.cu_.getStrongRefOrThrow();
    } else {
      auto weak_ptr = type_.cu_.getWeakRefOrThrow();
      return std::shared_ptr<torch::jit::CompilationUnit>(weak_ptr);
    }
  }

  c10::intrusive_ptr<Object> copy_to_weak_compilation_ref() const;

  // Downgrades this object's CU reference to non-owning in place.
  void unsafe_make_weak_compilation_ref() {
    type_ = WeakOrStrongTypePtr(type_.asWeakTypePtr());
  }

  c10::intrusive_ptr<Object> copy() const;

  c10::intrusive_ptr<Object> deepcopy(
      std::optional<at::Device> device = std::nullopt) const;

  c10::intrusive_ptr<Object> deepcopy(
      IValue::HashIdentityIValueMap& memo,
      std::optional<at::Device> device = std::nullopt) const;

  bool is_weak_compilation_ref() const {
    return !type_.holds_strong_ref();
  }

  bool is_empty_strong_compilation_ref() const {
    return type_.holds_empty_strong_ref();
  }

 private:
  // Grows slots_ so that `slot` becomes a valid index.
  void resizeObject(size_t slot);
  WeakOrStrongTypePtr type_;
  std::vector<IValue> slots_;
};
1609
+
1610
// Virtual ivalue PyObjectHolder that holds a py::object. It is kept abstract
// here because the py::object itself and its refcounting logic must live in
// libtorch_python; see the concrete implementation in python_ivalue.h.
struct ivalue::PyObjectHolder : c10::intrusive_ptr_target {
 public:
  virtual PyObject* getPyObject() = 0;
  virtual c10::InferredType tryToInferType() = 0;
  virtual IValue toIValue(const TypePtr& type, std::optional<int32_t> N = std::nullopt) = 0;
  virtual std::string toStr() = 0;
  virtual std::vector<at::Tensor> extractTensors() = 0;

  ~PyObjectHolder() override = default;
};
1623
+
1624
// Holds one member of a TorchScript enum: its enum type, the member name and
// the member's underlying value.
struct ivalue::EnumHolder : c10::intrusive_ptr_target {
 public:
  EnumHolder(std::shared_ptr<EnumType> type, std::string name, IValue value)
      : type_(std::move(type)),
        name_(std::move(name)),
        value_(std::move(value)) {}

  // Identity-style comparison, implemented via operator== below.
  bool is(const ivalue::EnumHolder& rhs) {
    return *this == rhs;
  }

  friend bool operator==(
      const ivalue::EnumHolder& lhs,
      const ivalue::EnumHolder& rhs);

  TORCH_API friend std::ostream& operator<<(
      std::ostream& out,
      const ivalue::EnumHolder& v);

  // Fully qualified name of the enum class.
  TORCH_API const std::string& qualifiedClassName() const;

  const std::string& unqualifiedClassName() const;

  // Name of this enum member.
  const std::string& name() const {
    return name_;
  }

  // Underlying value of this enum member.
  const IValue& value() const {
    return value_;
  }

  std::shared_ptr<EnumType> type() const {
    return type_;
  }

 private:
  std::shared_ptr<EnumType> type_;
  std::string name_;
  IValue value_;
};
1664
+
1665
+ #undef TORCH_FORALL_TAGS
1666
+
1667
namespace detail {

// On some platforms `unsigned long` is the very same type as `uint32_t` or
// `uint64_t`; specializing IValue::to<T>() for both would then be an illegal
// redefinition. `_guarded_unsigned_long` resolves to a dummy, otherwise
// unused type in that case, and to plain `unsigned long` only when it is a
// genuinely distinct type. The dummy keeps an implicit int64_t constructor
// so that the DEFINE_TO-generated static_cast from toInt() still compiles.
struct _guarded_unsigned_long_unique_dummy final {
  _guarded_unsigned_long_unique_dummy(int64_t) {} // note: fixed stray ';' after body
};
using _guarded_unsigned_long = std::conditional_t<
    std::is_same_v<unsigned long, uint32_t> ||
        std::is_same_v<unsigned long, uint64_t>,
    _guarded_unsigned_long_unique_dummy,
    unsigned long>;

} // namespace detail
1679
+
1680
// Returns a reference to the held Object without changing its refcount.
// Asserts the IValue holds an Object; the null check is debug-only.
inline ivalue::Object& IValue::toObjectRef() const {
  AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference");
  return *static_cast<c10::ivalue::Object*>(payload.u.as_intrusive_ptr);
}
1685
+
1686
// note: when adding a DEFINE_TO case here you should also add a
// toX method to IValue. These named methods are much more discoverable
// than the to templated function.

// DEFINE_TO(T, method_name) specializes both IValue::to<T>() overloads by
// forwarding to the named accessor: the rvalue overload moves out of *this,
// the const-lvalue overload copies. The const& overload's return type goes
// through ivalue_to_const_ref_overload_return<T>, which lets types such as
// at::Tensor be handed out by const reference rather than by value.
#define DEFINE_TO(T, method_name)                          \
  template <>                                              \
  inline T IValue::to<T>()&& {                             \
    return static_cast<T>(std::move(*this).method_name()); \
  }                                                        \
  template <>                                              \
  inline c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to<T>() const& { \
    typedef c10::detail::ivalue_to_const_ref_overload_return<T>::type return_type;          \
    return static_cast<return_type>(this->method_name());                                   \
  }
1700
+
1701
+ DEFINE_TO(at::Tensor, toTensor)
1702
+ DEFINE_TO(at::Storage, toStorage)
1703
+ DEFINE_TO(c10::Stream, toStream)
1704
+ DEFINE_TO(float, toDouble)
1705
+ DEFINE_TO(double, toDouble)
1706
+ DEFINE_TO(c10::complex<double>, toComplexDouble)
1707
+ DEFINE_TO(unsigned char, toInt)
1708
+ DEFINE_TO(signed char, toInt)
1709
+ DEFINE_TO(unsigned short, toInt)
1710
+ DEFINE_TO(short, toInt)
1711
+ DEFINE_TO(int, toInt)
1712
+ DEFINE_TO(uint32_t, toInt)
1713
+ DEFINE_TO(uint64_t, toInt)
1714
+ DEFINE_TO(detail::_guarded_unsigned_long, toInt)
1715
+ DEFINE_TO(int64_t, toInt)
1716
+ DEFINE_TO(bool, toBool)
1717
+ DEFINE_TO(c10::intrusive_ptr<caffe2::Blob>, toBlob);
1718
+ DEFINE_TO(c10::intrusive_ptr<ivalue::ConstantString>, toString)
1719
+ DEFINE_TO(c10::intrusive_ptr<ivalue::Object>, toObject)
1720
+ DEFINE_TO(at::Scalar, toScalar)
1721
+ DEFINE_TO(c10::List<int64_t>, toIntList)
1722
+ DEFINE_TO(c10::List<double>, toDoubleList)
1723
+ DEFINE_TO(c10::List<c10::complex<double>>, toComplexDoubleList)
1724
+ DEFINE_TO(c10::List<bool>, toBoolList)
1725
+ DEFINE_TO(c10::List<at::Tensor>, toTensorList)
1726
+ DEFINE_TO(c10::impl::GenericList, toList)
1727
+ DEFINE_TO(c10::impl::GenericDict, toGenericDict)
1728
+ DEFINE_TO(c10::intrusive_ptr<ivalue::Tuple>, toTuple)
1729
+ DEFINE_TO(std::string, toStringRef)
1730
+ DEFINE_TO(c10::string_view, toStringView)
1731
+ DEFINE_TO(c10::intrusive_ptr<ivalue::Future>, toFuture)
1732
+ DEFINE_TO(c10::intrusive_ptr<ivalue::Await>, toAwait)
1733
+ DEFINE_TO(c10::intrusive_ptr<c10::RRefInterface>, toRRef)
1734
+ DEFINE_TO(c10::intrusive_ptr<at::Quantizer>, toQuantizer)
1735
+ DEFINE_TO(IValue, toIValue)
1736
+ DEFINE_TO(c10::Device, toDevice)
1737
+ DEFINE_TO(at::ScalarType, toScalarType)
1738
+ DEFINE_TO(at::Layout, toLayout)
1739
+ DEFINE_TO(at::MemoryFormat, toMemoryFormat)
1740
+ DEFINE_TO(at::QScheme, toQScheme)
1741
+ DEFINE_TO(at::Dimname, toDimname)
1742
+ DEFINE_TO(at::Generator, toGenerator)
1743
+ DEFINE_TO(c10::SymInt, toSymInt)
1744
+ DEFINE_TO(c10::SymFloat, toSymFloat)
1745
+ DEFINE_TO(c10::SymBool, toSymBool)
1746
+
1747
// Tag type used purely to select generic_to overloads by return type.
template <class T>
struct _fake_type {};

// generic_to<T> converts an IValue from a generic list or generic dict
// to a concrete list/dict type like List<T>, Dict<...> or std::optional<T>.
// Note that in the case of lists, this only works for IValue-based lists,
// i.e. not for int64_t, double, ...
// generic_to<T> is an implementation detail of IValue::to<T> and not
// supposed to be called directly.
// The _fake_type<T> parameter allows us to overload
// based on the return type.
template <class Elem>
// TODO this is deprecated but we don't throw a warning because a lot of ops in
// native_functions.yaml still return std::vector.
// C10_DEPRECATED_MESSAGE("IValues based on std::vector<T> are potentially slow
// and deprecated. Please use torch::List<T> instead.")
std::vector<Elem> generic_to(IValue ivalue, _fake_type<std::vector<Elem>>) {
  // We need to do a deep copy of the vector because there might be other
  // references to this same IValue that also use the list. We can't just
  // move the elements out.
  auto list = std::move(ivalue).to<List<Elem>>();
  std::vector<Elem> result;
  result.reserve(list.size());
  for (Elem v : list) {
    result.push_back(std::move(v));
  }
  return result;
}
1775
+
1776
// Unwraps a custom-class IValue into the user's intrusive_ptr<T>. The held
// Object stores the user instance as a Capsule in its single slot; the class
// type is checked against the registered type for T before the cast.
// (The && and const& overloads are intentionally identical except for the
// value category of *this.)
template <typename T>
c10::intrusive_ptr<T> IValue::toCustomClass() && {
  static_assert(
      std::is_base_of<torch::CustomClassHolder, T>::value == true,
      "toCustomClass requires that template parameter T must inherit "
      "from torch::CustomClassHolder");
  auto obj = toObject();
  TORCH_CHECK(
      obj->slots().size() == 1,
      "Tried to cast IValue to custom class but it did "
      "not contain a custom class!");
  const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
  ivalue::checkCustomClassType(expected_type, type().get());
  auto userObj =
      c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
  return userObj;
}

// const& variant of the above; see the comment there.
template <typename T>
c10::intrusive_ptr<T> IValue::toCustomClass() const& {
  static_assert(
      std::is_base_of<torch::CustomClassHolder, T>::value == true,
      "toCustomClass requires that template parameter T must inherit "
      "from torch::CustomClassHolder");
  auto obj = toObject();
  TORCH_CHECK(
      obj->slots().size() == 1,
      "Tried to cast IValue to custom class but it did "
      "not contain a custom class!");
  const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
  ivalue::checkCustomClassType(expected_type, type().get());
  auto userObj =
      c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
  return userObj;
}
1811
+
1812
// Catch-all generic_to: T is expected to be an intrusive_ptr to a custom
// class; strip the pointer to recover the element type and unwrap.
template <typename T>
T generic_to(IValue ivalue, _fake_type<T>) {
  using ElemType = typename std::remove_pointer<T>::type::element_type;
  return std::move(ivalue).toCustomClass<ElemType>();
}

// Wraps the IValue unchanged; tagged_capsule defers the actual unwrapping.
template <typename T>
tagged_capsule<T> generic_to(IValue ivalue, _fake_type<tagged_capsule<T>>) {
  return tagged_capsule<T>{std::move(ivalue)};
}

// Reinterprets a generic list as a typed c10::List<Elem> (no element copy).
template <typename Elem>
c10::List<Elem> generic_to(IValue ivalue, _fake_type<c10::List<Elem>>) {
  return impl::toTypedList<Elem>(std::move(ivalue).toList());
}
1827
+
1828
// Deep-copies a raw ListImpl into any vector-like container T (anything with
// reserve/push_back and a value_type), converting each element with
// IValue::to<T::value_type>().
template <typename T>
static T createVectorLikeFromList(const c10::detail::ListImpl* impl) {
  T result;
  result.reserve(impl->list.size());
  for (const auto & i : impl->list) {
    result.push_back(i.to<typename T::value_type>());
  }
  return result;
}

// Convenience wrapper producing a std::vector<T> from a raw ListImpl.
template <typename T>
static std::vector<T> createVectorFromList(const c10::detail::ListImpl* impl) {
  return createVectorLikeFromList<std::vector<T>>(impl);
}
1842
+
1843
+ template <typename T>
1844
+ std::vector<T> createVectorFromList(const c10::List<T>& impl) {
1845
+ std::vector<T> result;
1846
+ result.reserve(impl.size());
1847
+ for (size_t i = 0, N = impl.size(); i < N; ++i) {
1848
+ result.push_back(impl[i]);
1849
+ }
1850
+ return result;
1851
+ }
1852
+
1853
// None maps to an empty OptionalArray; otherwise the list is deep-copied.
template <typename T>
OptionalArray<T> generic_to(IValue ivalue, _fake_type<OptionalArray<T>>) {
  if (ivalue.isNone()) {
    return {};
  }
  return createVectorFromList<T>(
    std::move(ivalue).to<c10::List<T>>()
  );
}

namespace detail {
// Expands a List into a fixed-size std::array via an index pack; errors if
// the runtime length does not match the array size.
template <typename Elem, size_t... I>
std::array<Elem, sizeof...(I)> generic_to_array(
    IValue ivalue,
    _fake_type<std::array<Elem, sizeof...(I)>>,
    std::index_sequence<I...>) {
  // We need to do a deep copy of the array because there might be other
  // references to this same IValue that also use the list. We can't just
  // move the elements out.
  auto list = std::move(ivalue).to<List<Elem>>();
  TORCH_CHECK(
      list.size() == sizeof...(I),
      "Tried to convert a List with ",
      list.size(),
      " elements to a fixed-size array of size ",
      sizeof...(I));
  return {list[I]...};
}
} // namespace detail
1882
+
1883
// Public std::array overload; delegates to detail::generic_to_array with a
// generated index sequence.
template <typename Elem, size_t N>
std::array<Elem, N> generic_to(
    IValue ivalue,
    _fake_type<std::array<Elem, N>> ft) {
  return detail::generic_to_array(ivalue, ft, std::make_index_sequence<N>());
}

// Reinterprets a generic dict as a typed c10::Dict (no element copy).
template <typename Key, typename Value>
c10::Dict<Key, Value> generic_to(
    IValue ivalue,
    _fake_type<c10::Dict<Key, Value>>) {
  return impl::toTypedDict<Key, Value>(std::move(ivalue).toGenericDict());
}

// Deep-copies a generic dict into a std::unordered_map, converting each key
// and value. Deprecated: prefer c10::Dict, which shares storage instead.
template <typename K, typename V>
C10_DEPRECATED_MESSAGE(
    "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict<K, V> instead.")
std::unordered_map<K, V> generic_to(
    IValue ivalue,
    _fake_type<std::unordered_map<K, V>>) {
  std::unordered_map<K, V> specialized_dict;

  for (const auto& item : std::move(ivalue).toGenericDict()) {
    specialized_dict[item.key().template to<K>()] = item.value().template to<V>();
  }

  return specialized_dict;
}

// None maps to nullopt; any other value is converted to T.
template <typename T>
std::optional<T> generic_to(IValue ivalue, _fake_type<std::optional<T>>) {
  if (ivalue.isNone()) {
    return std::nullopt;
  }
  return std::move(ivalue).to<T>();
}
1919
+
1920
namespace detail {
// Converts tuple elements one-by-one via an index pack.
template <typename Tuple, std::size_t... INDEX>
Tuple generic_to_tuple_impl(
    const ivalue::TupleElements& t,
    std::index_sequence<INDEX...>) {
  return std::make_tuple(
      t[INDEX].to<typename std::tuple_element<INDEX, Tuple>::type>()...);
}
} // namespace detail

// std::tuple overload. The enable_if rejects element types that are lvalue
// references or not constructible into an IValue, mirroring the constraints
// on the IValue tuple constructors. Checks the arity at runtime.
template <
    typename... Args,
    typename Indices = std::make_index_sequence<sizeof...(Args)>,
    std::enable_if_t<
        !std::disjunction_v<
            std::is_lvalue_reference<Args>...,
            std::negation<std::is_constructible<IValue, Args>>...>,
        std::nullptr_t> = nullptr>
std::tuple<Args...> generic_to(const IValue& ivalue, _fake_type<std::tuple<Args...>>) {
  const auto& vals = ivalue.toTupleRef().elements();
  TORCH_CHECK(vals.size() == sizeof...(Args));
  return detail::generic_to_tuple_impl<std::tuple<Args...>>(vals, Indices{});
}
1943
+
1944
// Fallback for the rvalue IValue::to<T>(): dispatch on return type through
// the generic_to overload set above.
template <typename T>
inline T IValue::to() && {
  return generic_to(std::move(*this), _fake_type<T>{});
}

// Special case: a string_view must keep pointing into the IValue's string
// storage, so the IValue cannot be consumed here.
template <>
inline std::optional<c10::string_view> IValue::to() && {
  // In the default implementation, the IValue is destroyed with std::move.
  // But if the unboxed type is std::optional<string_view> we cannot destroy
  // the IValue.
  return generic_to(*this, _fake_type<std::optional<c10::string_view>>{});
}

// Fallback for the const-lvalue IValue::to<T>().
template <typename T>
inline typename c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to() const& {
  return generic_to(*this, _fake_type<T>{});
}
1961
+
1962
// ---- Integer list accessors ----------------------------------------------
// Pattern used throughout: the && overload steals the intrusive_ptr out of
// *this, the const& overload bumps the refcount, and the toXVector() helpers
// deep-copy the elements into a standalone container.
inline c10::List<int64_t> IValue::toIntList() && {
  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
  return c10::List<int64_t>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<int64_t> IValue::toIntList() const& {
  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
  return c10::List<int64_t>(toIntrusivePtr<c10::detail::ListImpl>());
}
inline std::vector<int64_t> IValue::toIntVector() const {
  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toIntVector on null intrusive_ptr IValue");
  return createVectorFromList<int64_t>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
// Accepts a plain IntList as well; elements are converted through
// IValue::to<c10::SymInt>() inside createVectorFromList.
inline std::vector<c10::SymInt> IValue::toSymIntVector() const {
  AT_ASSERT(isSymIntList() || isIntList(), "Expected SymIntList or IntList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toSymIntVector on null intrusive_ptr IValue");
  return createVectorFromList<c10::SymInt>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
inline at::DimVector IValue::toDimVector() const {
  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toDimVector on null intrusive_ptr IValue");
  return createVectorLikeFromList<at::DimVector>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
1994
// ---- Floating-point list accessors ----------------------------------------
// Same &&/const&/Vector pattern as the integer accessors above.
inline c10::List<double> IValue::toDoubleList() && {
  AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
  return c10::List<double>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<double> IValue::toDoubleList() const& {
  AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
  return c10::List<double>(toIntrusivePtr<c10::detail::ListImpl>());
}
inline std::vector<double> IValue::toDoubleVector() const {
  AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toDoubleVector on null intrusive_ptr IValue");
  return createVectorFromList<double>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() && {
  AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
  return c10::List<c10::complex<double>>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() const& {
  AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
  return c10::List<c10::complex<double>>(toIntrusivePtr<c10::detail::ListImpl>());
}
inline std::vector<c10::complex<double>> IValue::toComplexDoubleVector() const {
  AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toComplexDoubleVector on null intrusive_ptr IValue");
  return createVectorFromList<c10::complex<double>>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
2026
// ---- Bool / Tensor list accessors ------------------------------------------
// Same &&/const&/Vector pattern as the accessors above. Note there is no
// toBoolVector(): std::vector<bool> is a packed proxy container.
inline c10::List<bool> IValue::toBoolList() && {
  AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
  return c10::List<bool>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<bool> IValue::toBoolList() const& {
  AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
  return c10::List<bool>(toIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<at::Tensor> IValue::toTensorList() && {
  AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
  return c10::List<at::Tensor>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<at::Tensor> IValue::toTensorList() const& {
  AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
  return c10::List<at::Tensor>(toIntrusivePtr<c10::detail::ListImpl>());
}
inline std::vector<at::Tensor> IValue::toTensorVector() const {
  AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toTensorVector on null intrusive_ptr IValue");
  return createVectorFromList<at::Tensor>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
inline c10::List<std::optional<at::Tensor>> IValue::toOptionalTensorList() && {
  AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
  return c10::List<std::optional<at::Tensor>>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<std::optional<at::Tensor>> IValue::toOptionalTensorList() const& {
  AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
  return c10::List<std::optional<at::Tensor>>(toIntrusivePtr<c10::detail::ListImpl>());
}
inline std::vector<std::optional<at::Tensor>> IValue::toOptionalTensorVector() const {
  AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toOptionalTensorVector on null intrusive_ptr IValue");
  return createVectorFromList<std::optional<at::Tensor>>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
2066
// ---- Generic list / dict / tuple accessors ---------------------------------
inline c10::List<IValue> IValue::toList() && {
  AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
  return c10::List<IValue>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<IValue> IValue::toList() const& {
  AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
  return c10::List<IValue>(toIntrusivePtr<c10::detail::ListImpl>());
}
// Non-owning view of the underlying element storage; valid only while this
// IValue (or another owner of the list) is alive.
inline c10::ArrayRef<IValue> IValue::toListRef() const {
  AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toListRef on null intrusive_ptr IValue");
  return static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr)
      ->list;
}
inline c10::Dict<IValue, IValue> IValue::toGenericDict() && {
  AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind());
  return c10::Dict<IValue, IValue>(moveToIntrusivePtr<c10::detail::DictImpl>());
}
inline c10::Dict<IValue, IValue> IValue::toGenericDict() const& {
  AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind());
  return c10::Dict<IValue, IValue>(toIntrusivePtr<c10::detail::DictImpl>());
}
inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() && {
  AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
  return moveToIntrusivePtr<ivalue::Tuple>();
}
inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() const& {
  AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
  return toIntrusivePtr<ivalue::Tuple>();
}
// Reference access without touching the refcount; null check is debug-only.
inline ivalue::Tuple& IValue::toTupleRef() const {
  AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toTupleRef on null intrusive_ptr IValue");
  return *static_cast<c10::ivalue::Tuple*>(
      payload.u.as_intrusive_ptr);
}
2106
+
2107
// Takes ownership of the Tuple; a null pointer is stored as the sentinel
// UndefinedTensorImpl singleton.
inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
    : tag(Tag::Tuple) {
  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
// std::tuple constructors: enabled only when every element type is a
// non-reference type constructible into an IValue (matching the tuple
// generic_to overload). The lvalue overload copies elements, the rvalue
// overload moves them.
template <
    typename... Args,
    std::enable_if_t<
        !std::disjunction_v<
            std::is_lvalue_reference<Args>...,
            std::negation<std::is_constructible<IValue, Args>>...>,
        std::nullptr_t>>
inline IValue::IValue(const std::tuple<Args...>& t)
    : IValue(c10::guts::apply(c10::ivalue::Tuple::create<const Args&...>, t)) {
}

template <
    typename... Args,
    std::enable_if_t<
        !std::disjunction_v<
            std::is_lvalue_reference<Args>...,
            std::negation<std::is_constructible<IValue, Args>>...>,
        std::nullptr_t>>
inline IValue::IValue(std::tuple<Args...>&& t)
    : IValue(c10::guts::apply(c10::ivalue::Tuple::create<Args&&...>, std::move(t))) {
}
2132
+
2133
+ inline IValue::IValue(c10::intrusive_ptr<ivalue::ConstantString> v)
2134
+ : tag(Tag::String) {
2135
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2136
+ }
2137
+ inline IValue::IValue(std::string v)
2138
+ : IValue(ivalue::ConstantString::create(std::move(v))) {}
2139
+
2140
+ inline IValue::IValue(c10::impl::GenericList v)
2141
+ : tag(Tag::GenericList) {
2142
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
2143
+ }
2144
+
2145
+ template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2146
+ inline IValue::IValue(c10::List<T>&& v) : IValue(impl::toList<T>(std::move(v))) {}
2147
+ template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2148
+ inline IValue::IValue(const c10::List<T>& v) : IValue(impl::toList<T>(v)) {}
2149
+ template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2150
+ inline IValue::IValue(at::ArrayRef<T> v) : IValue(c10::List<T>()) {
2151
+ auto list = to<c10::List<T>>();
2152
+ list.reserve(v.size());
2153
+ for (const auto& e : v) {
2154
+ list.push_back(e);
2155
+ }
2156
+ }
2157
+ template <class T, IValue::enable_if_symint<T>>
2158
+ inline IValue::IValue(at::ArrayRef<T> v) : IValue() {
2159
+ auto vi = c10::asIntArrayRefSlowOpt(v);
2160
+ if (vi.has_value()) {
2161
+ // This list is entirely integers; ensure it is typed as
2162
+ // an IntList so toIntList works
2163
+ *this = IValue(*vi);
2164
+ } else {
2165
+ // This list has SymInts; type it as a SymInt
2166
+ *this = IValue(impl::toList<c10::SymInt>(c10::List<c10::SymInt>()));
2167
+ auto list = to<c10::List<c10::SymInt>>();
2168
+ list.reserve(v.size());
2169
+ for (const auto& e : v) {
2170
+ list.push_back(e);
2171
+ }
2172
+ }
2173
+ }
2174
+ template <class T, IValue::enable_if_symint<T>>
2175
+ inline IValue::IValue(at::OptionalArrayRef<T> mb_v) : IValue() {
2176
+ if (!mb_v.has_value()) return;
2177
+ *this = IValue(*mb_v);
2178
+ }
2179
+ template <class T, IValue::enable_if_symint<T>>
2180
+ inline IValue::IValue(const std::vector<T>& v) : IValue() {
2181
+ *this = IValue(at::ArrayRef<T>(v));
2182
+ }
2183
+ template <class T, IValue::enable_if_symint<T>>
2184
+ inline IValue::IValue(std::vector<T>&& v) : IValue() {
2185
+ auto vi = c10::asIntArrayRefSlowOpt(v);
2186
+ if (vi.has_value()) {
2187
+ // This list is entirely integers; ensure it is typed as
2188
+ // an IntList so toIntList works
2189
+ *this = IValue(*vi);
2190
+ } else {
2191
+ // This list has SymInts; type it as a SymInt
2192
+ *this = IValue(impl::toList<c10::SymInt>(c10::List<c10::SymInt>()));
2193
+ auto list = to<c10::List<c10::SymInt>>();
2194
+ list.reserve(v.size());
2195
+ for (auto&& e : std::move(v)) {
2196
+ list.push_back(std::move(e));
2197
+ }
2198
+ }
2199
+ }
2200
+ template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2201
+ inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) {
2202
+ auto list = to<c10::List<T>>();
2203
+ list.reserve(v.size());
2204
+ for (const auto& e : v) {
2205
+ list.push_back(e);
2206
+ }
2207
+ }
2208
+
2209
+ template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2210
+ inline IValue::IValue(std::vector<T>&& v) : IValue(c10::List<T>()) {
2211
+ auto list = to<c10::List<T>>();
2212
+ list.reserve(v.size());
2213
+ if constexpr (std::is_same_v<T, bool>) {
2214
+ for (auto e : v) {
2215
+ list.push_back(e);
2216
+ }
2217
+ } else {
2218
+ for (auto&& e : std::move(v)) {
2219
+ list.push_back(std::move(e));
2220
+ }
2221
+ }
2222
+ }
2223
+
2224
+ template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2225
+ inline IValue::IValue(c10::OptionalArrayRef<T> v) : IValue() {
2226
+ if (v.has_value()) {
2227
+ *this = IValue(std::move(*v));
2228
+ }
2229
+ }
2230
+
2231
+ template <class T, size_t N>
2232
+ inline IValue::IValue(std::array<T, N> v) : IValue(c10::List<T>()) {
2233
+ auto list = to<c10::List<T>>();
2234
+ list.reserve(v.size());
2235
+ for (auto& e : v) {
2236
+ list.push_back(std::move(e));
2237
+ }
2238
+ }
2239
+
2240
+ template <class T, IValue::enable_if_ilist_is_ivalue_constructible<T>>
2241
+ inline IValue::IValue(c10::IListRef<T> v) : IValue() {
2242
+ constexpr bool boxed_type_constructs_ivalue =
2243
+ std::is_constructible<IValue, typename c10::IListRef<T>::boxed_type>::value;
2244
+ // First, we try to use the boxed value.
2245
+ // If we fail (either it's not in the boxed state, or its boxed type
2246
+ // can not construct an IValue), we fallback to copying the list.
2247
+ if (boxed_type_constructs_ivalue && v.isBoxed()) {
2248
+ *this = IValue(impl::toList(v.toBoxed()));
2249
+ } else {
2250
+ c10::List<T> list;
2251
+ list.reserve(v.size());
2252
+ for (const auto& t : v) {
2253
+ list.push_back(t);
2254
+ }
2255
+ *this = IValue(impl::toList(std::move(list)));
2256
+ }
2257
+ }
2258
+
2259
+ inline IValue::IValue(c10::impl::GenericDict v)
2260
+ : tag(Tag::GenericDict) {
2261
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
2262
+ }
2263
+ template <class Key, class Value>
2264
+ inline IValue::IValue(c10::Dict<Key, Value> v)
2265
+ : IValue(impl::toGenericDict(std::move(v))) {}
2266
+
2267
+ template <class Key, class Value>
2268
+ inline IValue::IValue(std::unordered_map<Key, Value> v)
2269
+ : IValue(Dict<Key, Value>()) {
2270
+ auto dict = to<c10::Dict<Key, Value>>();
2271
+ dict.reserve(v.size());
2272
+ for (auto& e : v) {
2273
+ dict.insert(std::move(e.first), std::move(e.second));
2274
+ }
2275
+ }
2276
+
2277
+ template <class T, IValue::enable_if_ivalue_constructible<T>>
2278
+ inline IValue::IValue(std::optional<T> v) : IValue() {
2279
+ if (v.has_value()) {
2280
+ *this = IValue(std::move(*v));
2281
+ }
2282
+ }
2283
+
2284
+ inline IValue::IValue(std::nullopt_t) : IValue() {}
2285
+
2286
+ inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
2287
+ : tag(Tag::Object) {
2288
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2289
+ }
2290
+
2291
+ inline IValue::IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v)
2292
+ : tag(Tag::PyObject) {
2293
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2294
+ }
2295
+
2296
+ inline IValue::IValue(c10::intrusive_ptr<ivalue::EnumHolder> v)
2297
+ : tag(Tag::Enum) {
2298
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2299
+ }
2300
+
2301
+ inline IValue IValue::make_capsule(
2302
+ intrusive_ptr<torch::CustomClassHolder> blob) {
2303
+ IValue iv;
2304
+ iv.tag = Tag::Capsule;
2305
+ iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
2306
+ return iv;
2307
+ }
2308
+
2309
+ template <
2310
+ typename T,
2311
+ std::enable_if_t<std::is_base_of_v<torch::CustomClassHolder, T>, int>>
2312
+ IValue::IValue(c10::intrusive_ptr<T> custom_class) : tag(Tag::Object) {
2313
+ auto classType = []() {
2314
+ try {
2315
+ return c10::getCustomClassType<c10::intrusive_ptr<T>>();
2316
+ } catch (const c10::Error&) {
2317
+ throw c10::Error(
2318
+ "Trying to instantiate a class that isn't a registered custom class: " +
2319
+ std::string(c10::util::get_fully_qualified_type_name<T>()));
2320
+ }
2321
+ }();
2322
+ auto ivalue_obj = c10::ivalue::Object::create(std::move(classType), /* numSlots */1);
2323
+ ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
2324
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release());
2325
+
2326
+ }
2327
+
2328
+ inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
2329
+ : tag(Tag::Future) {
2330
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2331
+ }
2332
+
2333
+ inline IValue::IValue(c10::intrusive_ptr<ivalue::Await> v)
2334
+ : tag(Tag::Await) {
2335
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2336
+ }
2337
+
2338
+ inline IValue::IValue(c10::intrusive_ptr<c10::RRefInterface> v)
2339
+ : tag(Tag::RRef) {
2340
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2341
+ }
2342
+
2343
+ inline IValue::IValue(c10::intrusive_ptr<at::Quantizer> v)
2344
+ : tag(Tag::Quantizer) {
2345
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2346
+ }
2347
+
2348
+ template <typename T>
2349
+ inline IValue::IValue(c10::complex<T> c)
2350
+ : tag(Tag::ComplexDouble) {
2351
+ auto v = c10::make_intrusive<ivalue::ComplexHolder>(c);
2352
+ payload.u.as_intrusive_ptr = v.release();
2353
+ }
2354
+
2355
+ inline const std::string& IValue::toStringRef() const {
2356
+ AT_ASSERT(isString(), "Expected String but got ", tagKind());
2357
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2358
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2359
+ "called toStringRef on null intrusive_ptr IValue");
2360
+ return static_cast<const c10::ivalue::ConstantString*>(
2361
+ payload.u.as_intrusive_ptr)
2362
+ ->string();
2363
+ }
2364
+ inline std::optional<std::reference_wrapper<const std::string>> IValue::
2365
+ toOptionalStringRef() const {
2366
+ if (isNone()) {
2367
+ return std::nullopt;
2368
+ }
2369
+ AT_ASSERT(isString(), "Expected std::optional<string> but got ", tagKind());
2370
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2371
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2372
+ "called toOptionalStringRef on null intrusive_ptr IValue");
2373
+ return std::reference_wrapper<const std::string>(
2374
+ static_cast<const c10::ivalue::ConstantString*>(payload.u.as_intrusive_ptr)
2375
+ ->string());
2376
+ }
2377
+
2378
+ inline c10::string_view IValue::toStringView() const {
2379
+ AT_ASSERT(isString(), "Expected String but got ", tagKind());
2380
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2381
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2382
+ "called toStringView on null intrusive_ptr IValue");
2383
+ return static_cast<const c10::ivalue::ConstantString*>(
2384
+ payload.u.as_intrusive_ptr)
2385
+ ->string_view();
2386
+ }
2387
+
2388
+ inline PyObject* IValue::toPyObject() const {
2389
+ return toPyObjectHolder()->getPyObject();
2390
+ }
2391
+
2392
+ template <typename T>
2393
+ inline std::optional<T> IValue::toOptional() {
2394
+ if (this->isNone()) {
2395
+ return std::nullopt;
2396
+ }
2397
+ return this->to<T>();
2398
+ }
2399
+
2400
+ template <typename T>
2401
+ inline std::optional<T> IValue::toOptional() const {
2402
+ if (this->isNone()) {
2403
+ return std::nullopt;
2404
+ }
2405
+ return this->to<T>();
2406
+ }
2407
+
2408
+ inline bool IValue::isCustomClass() const {
2409
+ return torch::isCustomClass(*this);
2410
+ }
2411
+
2412
+ inline bool IValue::isSameIdentity(const IValue& rhs) const {
2413
+ // We choose to not use memcmp for payload check due to potential random
2414
+ // padding characters on union type
2415
+
2416
+ // Semantics:
2417
+ // 1. Immutable primitive values of the same type (Int, Double, None, Bool,
2418
+ // Str) return value equality
2419
+ // 2. If it is a tensor type, we need to take undefined tensor into account
2420
+ // 3. Undefined_tensor is None and vice versa should be true
2421
+ // 4. If it is a reference type (i.e. isIntrusivePtr()), then is True when
2422
+ // the pointed-to object is the same.
2423
+ // 5. False for all other comparisons.
2424
+ if (this->isNone() && rhs.isNone()) {
2425
+ return true;
2426
+ } else if (this->isBool() && rhs.isBool()) {
2427
+ // for bool type, do equality check
2428
+ return this->toBool() == rhs.toBool();
2429
+ } else if (this->isTensor() && rhs.isTensor()) {
2430
+ return this->payload.as_tensor.is_same(rhs.payload.as_tensor);
2431
+ } else if (this->isTensor() && rhs.isNone()) {
2432
+ // special case: undefined tensor and None are the same identity
2433
+ return !this->payload.as_tensor.defined();
2434
+ } else if (this->isNone() && rhs.isTensor()) {
2435
+ // special case: undefined tensor and None are the same identity
2436
+ return !rhs.payload.as_tensor.defined();
2437
+ } else if (this->isInt() && rhs.isInt()) {
2438
+ return this->toInt() == rhs.toInt();
2439
+ } else if (this->isDouble() && rhs.isDouble()) {
2440
+ return this->toDouble() == rhs.toDouble();
2441
+ } else if (this->isString() && rhs.isString()) {
2442
+ return this->toStringRef() == rhs.toStringRef();
2443
+ } else {
2444
+ // for objects holding in IValue, do shallow compare on pointer address to
2445
+ // testify the identity
2446
+ return this->isIntrusivePtr() && rhs.isIntrusivePtr() &&
2447
+ this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
2448
+ }
2449
+ }
2450
+
2451
+ namespace ivalue {
2452
+ namespace detail {
2453
+
2454
+ template <typename T>
2455
+ IValue from_(T&& x, std::true_type) {
2456
+ return IValue(std::forward<T>(x));
2457
+ }
2458
+ template <typename T>
2459
+ IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
2460
+ return IValue(std::move(x));
2461
+ }
2462
+ template <typename T>
2463
+ IValue from_(T&& /*x*/, std::false_type) {
2464
+ static_assert(
2465
+ guts::false_t<T>::value,
2466
+ "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)");
2467
+ return IValue();
2468
+ }
2469
+ } // namespace detail
2470
+
2471
+ template <typename T>
2472
+ IValue from(T&& x) {
2473
+ return detail::from_(
2474
+ std::forward<T>(x), typename std::is_constructible<IValue, T>::type{});
2475
+ }
2476
+
2477
+ } // namespace ivalue
2478
+
2479
+
2480
+ template <>
2481
+ struct MaybeOwnedTraits<IValue> {
2482
+ using owned_type = IValue;
2483
+ using borrow_type = IValue;
2484
+
2485
+ static borrow_type createBorrow(const owned_type& from) {
2486
+ if (!from.isPtrType()) {
2487
+ return from;
2488
+ }
2489
+ if (from.isTensor()) {
2490
+ return IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(from.toTensor()));
2491
+ } else {
2492
+ return IValue(from.payload, from.tag);
2493
+ }
2494
+ }
2495
+
2496
+ static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
2497
+ lhs.clearToNone();
2498
+ if (!rhs.isPtrType()) {
2499
+ lhs = rhs;
2500
+ } else if (rhs.isTensor()) {
2501
+ lhs = IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(rhs.toTensor()));
2502
+ } else {
2503
+ lhs = IValue(rhs.payload, rhs.tag);
2504
+ }
2505
+ }
2506
+
2507
+ static void destroyBorrow(borrow_type& toDestroy) {
2508
+ toDestroy.clearToNone();
2509
+ }
2510
+
2511
+ static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
2512
+ return borrow;
2513
+ }
2514
+
2515
+ static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
2516
+ return &borrow;
2517
+ }
2518
+
2519
+ static bool debugBorrowIsValid(const borrow_type&) {
2520
+ return true;
2521
+ }
2522
+ };
2523
+
2524
+ template <>
2525
+ struct IValue::TagType<c10::Type> {
2526
+ static TORCH_API c10::TypePtr get(const IValue&);
2527
+ };
2528
+
2529
+ template <>
2530
+ struct IValue::TagType<c10::DynamicType> {
2531
+ static TORCH_API c10::TypePtr get(const IValue&);
2532
+ };
2533
+
2534
+ template <typename T>
2535
+ TypePtr IValue::type() const {
2536
+ return IValue::TagType<T>::get(*this);
2537
+ }
2538
+
2539
+ } // namespace c10
.venv/lib/python3.11/site-packages/torch/include/ATen/core/rref_interface.h ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/util/intrusive_ptr.h>
4
+ #include <ATen/core/jit_type_base.h>
5
+
6
+ namespace c10 {
7
+
8
+ struct Type;
9
+ using worker_id_t = int16_t;
10
+
11
+ // This abstract class contains only user-facing APIs, and will be shared
12
+ // between jit and distributed to implement TorchScript support.
13
+ class C10_EXPORT RRefInterface : public c10::intrusive_ptr_target {
14
+ public:
15
+ RRefInterface() = default;
16
+ // RRef is made NOT copyable NOT movable to prevent messing up reference
17
+ // counting.
18
+ RRefInterface(const RRefInterface& other) = delete;
19
+ RRefInterface(RRefInterface&& other) = delete;
20
+ RRefInterface& operator=(RRefInterface&& other) = delete;
21
+
22
+ ~RRefInterface() override = default;
23
+
24
+ // returns the worker id of the owner
25
+ virtual worker_id_t owner() const = 0;
26
+
27
+ // returns the worker name of the owner
28
+ virtual std::string ownerName() const = 0;
29
+
30
+ // Returns true if this is the ``OwnerRRef``
31
+ virtual bool isOwner() const = 0;
32
+
33
+ // Returns true if this is an ``OwnerRRef`` or if this ``UserRRef`` has been
34
+ // confirmed by its owner.
35
+ virtual bool confirmedByOwner() const = 0;
36
+
37
+ virtual const TypePtr type() const = 0;
38
+ };
39
+
40
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/core/stack.h ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <type_traits>
4
+
5
+ #include <ATen/core/ivalue.h>
6
+ #include <c10/util/Deprecated.h>
7
+ #include <c10/util/irange.h>
8
+
9
+ // TODO move this to c10 namespace
10
+
11
+
12
+ namespace torch::jit {
13
+
14
+ using c10::IValue;
15
+ using Stack = std::vector<IValue>;
16
+
17
+ class Operation {
18
+ template <typename F, typename Arg>
19
+ using accepts = std::is_constructible<std::function<void(Arg)>, F&&>;
20
+
21
+ public:
22
+ template <typename F,
23
+ std::enable_if_t<accepts<F, Stack*>::value, int> = 0>
24
+ C10_DEPRECATED_MESSAGE("Please use void(Stack&) to register operator instead.")
25
+ // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
26
+ Operation(F&& raw): op_([raw = std::forward<F>(raw)](Stack& stack) {
27
+ raw(&stack);
28
+ }) {}
29
+
30
+ template <typename F,
31
+ std::enable_if_t<accepts<F, Stack&>::value &&
32
+ !std::is_same_v<std::decay_t<F>, Operation>, int> = 0>
33
+ Operation(F&& op): op_(std::forward<F>(op)) {}
34
+
35
+ Operation(std::nullptr_t) noexcept {}
36
+
37
+ explicit operator bool() const noexcept {
38
+ return op_ ? true : false;
39
+ }
40
+
41
+ void operator()(Stack& stack) {
42
+ op_(stack);
43
+ }
44
+
45
+ template <typename T>
46
+ T* target() noexcept {
47
+ return op_.target<T>();
48
+ }
49
+
50
+ private:
51
+ std::function<void(Stack&)> op_;
52
+ };
53
+
54
+ // An operation with N inputs and M outputs pops the last N inputs off
55
+ // the stack and pushes its M inputs onto the stack
56
+ // before: <other stack items> I0, I1, ... IN <- stack.back()
57
+ // after: <other stack items> O0, O1, ... OM
58
+ // operations are defined this way so that ownership of inputs can be
59
+ // transferred to the operation and it can incrementally drop ownership of
60
+ // tensors when they become unneeded. For large operations, like 'run an entire
61
+ // subgraph', this functionality is very important for minimizing gpu memory
62
+ // usage return value is the relative 'offset' to jump to for the next
63
+ // operation:
64
+ // pc += 1 + offset
65
+ // so a return value of 0 goes to the next instruction
66
+
67
+ // treat the last N elements of the stack as a list, looking up
68
+ // element i
69
+ inline IValue& peek(Stack& stack, size_t i, size_t N) {
70
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
71
+ return *(stack.end() - N + i);
72
+ }
73
+ inline IValue& peek(Stack* stack, size_t i, size_t N) {
74
+ return peek(*stack, i, N);
75
+ }
76
+ inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
77
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
78
+ return *(stack.end() - N + i);
79
+ }
80
+ inline const IValue& peek(const Stack* stack, size_t i, size_t N) {
81
+ return peek(*stack, i, N);
82
+ }
83
+ // treat the last N elements of the stack as a list, looking up the
84
+ // slice starting at index i and having length len
85
+ inline at::ArrayRef<IValue> peekSlice(
86
+ const Stack& stack,
87
+ size_t i,
88
+ size_t len,
89
+ size_t N) {
90
+ return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len);
91
+ }
92
+ inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
93
+ return peekSlice(stack, 0, N, N);
94
+ }
95
+ inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) {
96
+ return last(*stack, N);
97
+ }
98
+ inline void drop(Stack& stack, size_t n) {
99
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
100
+ stack.erase(stack.end() - n, stack.end());
101
+ }
102
+ inline void drop(Stack* stack, size_t n) {
103
+ drop(*stack, n);
104
+ }
105
+ inline IValue pop(Stack& stack) {
106
+ auto r = std::move(stack.back());
107
+ stack.pop_back();
108
+ return r;
109
+ }
110
+ inline IValue pop(Stack* stack) {
111
+ return pop(*stack);
112
+ }
113
+ inline std::vector<IValue> pop(Stack& stack, size_t n) {
114
+ std::vector<IValue> result;
115
+ result.reserve(n);
116
+ for (const auto i : c10::irange(n)) {
117
+ result.push_back(std::move(peek(stack, i, n)));
118
+ }
119
+ drop(stack, n);
120
+ return result;
121
+ }
122
+
123
+ // variadic pop:
124
+ // int64_t a; at::Tensor b;
125
+ // pop(stack, a, b);
126
+ // equivalent to:
127
+ // b = pop(stack).toTensor();
128
+ // a = pop(stack).toInt();
129
+ template <typename... Types>
130
+ inline void pop(Stack& stack, Types&... args) {
131
+ size_t i = 0;
132
+ constexpr size_t N = sizeof...(args);
133
+ (void)std::initializer_list<int>{
134
+ (args = std::move(peek(stack, i++, N)).template to<Types>(), 0)...};
135
+ drop(stack, N);
136
+ }
137
+ template <typename... Types>
138
+ inline void pop(Stack* stack, Types&... args) {
139
+ pop(*stack, args...);
140
+ }
141
+ template <typename Type>
142
+ inline void push_one(Stack& stack, Type&& arg) {
143
+ stack.emplace_back(std::forward<Type>(arg));
144
+ }
145
+
146
+ inline void push_one(Stack& stack, c10::TensorOptions options) {
147
+ stack.emplace_back(c10::typeMetaToScalarType(options.dtype()));
148
+ stack.emplace_back(options.layout());
149
+ stack.emplace_back(options.device());
150
+ stack.emplace_back(options.pinned_memory());
151
+ }
152
+
153
+ template <typename... Types>
154
+ inline void push(Stack& stack, Types&&... args) {
155
+ (void)std::initializer_list<int>{(push_one(stack, std::forward<Types>(args)), 0)...};
156
+ }
157
+ template <typename... Types>
158
+ inline void push(Stack* stack, Types&&... args) {
159
+ return push(*stack, std::forward<Types>(args)...);
160
+ }
161
+ template <class T>
162
+ inline void push_list_elements(Stack& stack, const c10::List<T>& elements) {
163
+ for (T elem : elements) {
164
+ stack.push_back(std::move(elem));
165
+ }
166
+ }
167
+
168
+ // The packer here is carefully written not to make any unnecessary
169
+ // copies.
170
+
171
+ // pack takes the return values of aten functions pushes them onto the stack
172
+ template <typename T>
173
+ inline void pack(Stack& stack, T&& v) {
174
+ stack.emplace_back(std::forward<T>(v));
175
+ }
176
+ template <typename T>
177
+ inline void pack(Stack* stack, T&& v) {
178
+ pack(*stack, std::forward<T>(v));
179
+ }
180
+
181
+ template <std::size_t remaining, typename... Args>
182
+ struct TuplePacker {
183
+ // NB: *Not* a universal reference.
184
+ static void execute(Stack& stack, std::tuple<Args...>&& t) {
185
+ // NB: The move here does not "destroy" the entire tuple, that is
186
+ // not what std::move does; only the particular tuple index
187
+ // processed here gets stolen.
188
+ pack(stack, std::get<sizeof...(Args) - remaining>(std::move(t)));
189
+ TuplePacker<remaining - 1, Args...>::execute(stack, std::move(t));
190
+ }
191
+ };
192
+
193
+ template <typename... Args>
194
+ struct TuplePacker<0, Args...> {
195
+ // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
196
+ static void execute(Stack& /*stack*/, std::tuple<Args...>&& /*t*/){};
197
+ };
198
+
199
+ template <typename... Args>
200
+ inline void pack(Stack& stack, std::tuple<Args...>&& t) {
201
+ TuplePacker<sizeof...(Args), Args...>::execute(stack, std::move(t));
202
+ }
203
+
204
+ } // namespace torch::jit
.venv/lib/python3.11/site-packages/torch/include/ATen/core/type_factory.h ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <type_traits>
4
+ #include <unordered_map>
5
+
6
+ #include <ATen/core/dynamic_type.h>
7
+ #include <ATen/core/jit_type_base.h>
8
+ #include <c10/macros/Macros.h>
9
+
10
+ namespace c10 {
11
+
12
+ template <typename T>
13
+ struct TORCH_API TypeFactoryBase {};
14
+
15
+ template <>
16
+ struct TORCH_API TypeFactoryBase<c10::DynamicType> {
17
+ template <typename T, typename... Args>
18
+ static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) {
19
+ return std::make_shared<c10::DynamicType>(
20
+ c10::DynamicTypeTrait<T>::tagValue(),
21
+ c10::DynamicType::Arguments(c10::ArrayRef<c10::TypePtr>(
22
+ {std::move(ty), std::forward<Args>(args)...})));
23
+ }
24
+ template <typename T>
25
+ static c10::DynamicTypePtr create(const std::vector<c10::TypePtr>& types) {
26
+ return std::make_shared<c10::DynamicType>(
27
+ c10::DynamicTypeTrait<T>::tagValue(),
28
+ c10::DynamicType::Arguments(types));
29
+ }
30
+ static c10::DynamicTypePtr createNamedTuple(
31
+ const std::string& name,
32
+ const std::vector<c10::string_view>& fields,
33
+ const std::vector<c10::TypePtr>& types) {
34
+ return std::make_shared<c10::DynamicType>(
35
+ c10::DynamicType::Tag::Tuple,
36
+ name,
37
+ c10::DynamicType::Arguments(fields, types));
38
+ }
39
+ template <typename T>
40
+ C10_ERASE static c10::DynamicTypePtr createNamed(const std::string& name) {
41
+ return std::make_shared<c10::DynamicType>(
42
+ c10::DynamicTypeTrait<T>::tagValue(),
43
+ name,
44
+ c10::DynamicType::Arguments{});
45
+ }
46
+ template <typename T>
47
+ C10_ERASE static c10::DynamicTypePtr get() {
48
+ return DynamicTypeTrait<T>::getBaseType();
49
+ }
50
+ static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
51
+ };
52
+
53
+ using DynamicTypeFactory = TypeFactoryBase<c10::DynamicType>;
54
+
55
+ // Helper functions for constructing DynamicTypes inline.
56
+ template <
57
+ typename T,
58
+ std::enable_if_t<DynamicTypeTrait<T>::isBaseType, int> = 0>
59
+ C10_ERASE DynamicTypePtr dynT() {
60
+ return DynamicTypeFactory::get<T>();
61
+ }
62
+
63
+ template <
64
+ typename T,
65
+ typename... Args,
66
+ std::enable_if_t<!DynamicTypeTrait<T>::isBaseType, int> = 0>
67
+ C10_ERASE DynamicTypePtr dynT(Args&&... args) {
68
+ return DynamicTypeFactory::create<T>(std::forward<Args>(args)...);
69
+ }
70
+
71
+ template <>
72
+ struct TORCH_API TypeFactoryBase<c10::Type> {
73
+ template <typename T, typename... Args>
74
+ static c10::TypePtr create(TypePtr ty, Args&&... args) {
75
+ return T::create(std::move(ty), std::forward<Args>(args)...);
76
+ }
77
+ template <typename T>
78
+ static c10::TypePtr create(std::vector<c10::TypePtr> types) {
79
+ return T::create(std::move(types));
80
+ }
81
+ static c10::TypePtr createNamedTuple(
82
+ const std::string& name,
83
+ const std::vector<c10::string_view>& fields,
84
+ const std::vector<c10::TypePtr>& types);
85
+ template <typename T>
86
+ C10_ERASE static c10::TypePtr createNamed(const std::string& name) {
87
+ return T::create(name);
88
+ }
89
+ static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
90
+ template <typename T>
91
+ C10_ERASE static c10::TypePtr get() {
92
+ return T::get();
93
+ }
94
+ };
95
+
96
+ using DefaultTypeFactory = TypeFactoryBase<c10::Type>;
97
+
98
+ using PlatformType =
99
+ #ifdef C10_MOBILE
100
+ c10::DynamicType
101
+ #else
102
+ c10::Type
103
+ #endif
104
+ ;
105
+
106
+ using TypeFactory = TypeFactoryBase<PlatformType>;
107
+
108
+ } // namespace c10
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_batch_norm_with_update_native.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _batch_norm_with_update_functional(const at::Tensor & input, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps);
20
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _batch_norm_with_update_cpu_out(const at::Tensor & input, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, at::Tensor & reserve);
21
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _batch_norm_with_update_cpu(const at::Tensor & input, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps);
22
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _batch_norm_with_update_cuda_out(const at::Tensor & input, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, at::Tensor & reserve);
23
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _batch_norm_with_update_cuda(const at::Tensor & input, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps);
24
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _batch_norm_with_update_mkldnn(const at::Tensor & input, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps);
25
+ } // namespace native
26
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_coalesced_native.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor _coalesced(const at::Tensor & self, bool coalesced);
20
+ TORCH_API at::Tensor & _coalesced_out(const at::Tensor & self, bool coalesced, at::Tensor & out);
21
+ TORCH_API at::Tensor & _coalesced_sparse_(at::Tensor & self, bool coalesced);
22
+ } // namespace native
23
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_convert_weight_to_int4pack_cpu_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cpu {
19
+
20
+ TORCH_API at::Tensor _convert_weight_to_int4pack(const at::Tensor & self, int64_t innerKTiles);
21
+
22
+ } // namespace cpu
23
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cudnn_rnn_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _cudnn_rnn_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional<at::Tensor> & weight_buf, const at::Tensor & hx, const ::std::optional<at::Tensor> & cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional<at::Tensor> & dropout_state);
21
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _cudnn_rnn_outf(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional<at::Tensor> & weight_buf, const at::Tensor & hx, const ::std::optional<at::Tensor> & cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional<at::Tensor> & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4);
22
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _cudnn_rnn_symint_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional<at::Tensor> & weight_buf, const at::Tensor & hx, const ::std::optional<at::Tensor> & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional<at::Tensor> & dropout_state);
23
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _cudnn_rnn_symint_outf(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional<at::Tensor> & weight_buf, const at::Tensor & hx, const ::std::optional<at::Tensor> & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional<at::Tensor> & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4);
24
+
25
+ } // namespace compositeexplicitautograd
26
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_dimI.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <optional>
17
+
18
+
19
+
20
+ #include <ATen/ops/_dimI_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+
26
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_embedding_bag_per_sample_weights_backward_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor & _embedding_bag_per_sample_weights_backward_out(at::Tensor & out, const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx=-1);
21
+ TORCH_API at::Tensor & _embedding_bag_per_sample_weights_backward_outf(const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx, at::Tensor & out);
22
+
23
+ } // namespace compositeexplicitautograd
24
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_embedding_bag_sparse_backward.h ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <optional>
17
+
18
+
19
+
20
+ #include <ATen/ops/_embedding_bag_sparse_backward_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
26
+ inline at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
27
+ return at::_ops::_embedding_bag_sparse_backward::call(grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx);
28
+ }
29
+ namespace symint {
30
+ template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
31
+ at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
32
+ return at::_ops::_embedding_bag_sparse_backward::call(grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx);
33
+ }
34
+ }
35
+
36
+ // aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
37
+ inline at::Tensor _embedding_bag_sparse_backward_symint(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
38
+ return at::_ops::_embedding_bag_sparse_backward::call(grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx);
39
+ }
40
+ namespace symint {
41
+ template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
42
+ at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
43
+ return at::_ops::_embedding_bag_sparse_backward::call(grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx);
44
+ }
45
+ }
46
+
47
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_max.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <optional>
17
+
18
+
19
+
20
+ #include <ATen/ops/_foreach_max_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_foreach_max(Tensor[] self) -> Tensor[]
26
+ inline ::std::vector<at::Tensor> _foreach_max(at::TensorList self) {
27
+ return at::_ops::_foreach_max::call(self);
28
+ }
29
+
30
+ // aten::_foreach_max.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
31
+ inline void _foreach_max_out(at::TensorList out, at::TensorList self) {
32
+ return at::_ops::_foreach_max_out::call(self, out);
33
+ }
34
+ // aten::_foreach_max.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
35
+ inline void _foreach_max_outf(at::TensorList self, at::TensorList out) {
36
+ return at::_ops::_foreach_max_out::call(self, out);
37
+ }
38
+
39
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_round_cuda_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API ::std::vector<at::Tensor> _foreach_round(at::TensorList self);
21
+ TORCH_API void _foreach_round_(at::TensorList self);
22
+
23
+ } // namespace cuda
24
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_lu_with_info.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <optional>
17
+
18
+
19
+
20
+ #include <ATen/ops/_lu_with_info_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
26
+ inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _lu_with_info(const at::Tensor & self, bool pivot=true, bool check_errors=true) {
27
+ return at::_ops::_lu_with_info::call(self, pivot, check_errors);
28
+ }
29
+
30
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_scale_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor & _masked_scale_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, double scale);
21
+ TORCH_API at::Tensor & _masked_scale_outf(const at::Tensor & self, const at::Tensor & mask, double scale, at::Tensor & out);
22
+
23
+ } // namespace compositeexplicitautograd
24
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_native.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor & _masked_softmax_out(const at::Tensor & self, const at::Tensor & mask, ::std::optional<int64_t> dim, ::std::optional<int64_t> mask_type, at::Tensor & out);
20
+ TORCH_API at::Tensor masked_softmax_cpu(const at::Tensor & self, const at::Tensor & mask, ::std::optional<int64_t> dim=::std::nullopt, ::std::optional<int64_t> mask_type=::std::nullopt);
21
+ TORCH_API at::Tensor masked_softmax_cuda(const at::Tensor & self, const at::Tensor & mask, ::std::optional<int64_t> dim=::std::nullopt, ::std::optional<int64_t> mask_type=::std::nullopt);
22
+ } // namespace native
23
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_from_padded_and_nested_example_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor & _nested_from_padded_and_nested_example_out(at::Tensor & out, const at::Tensor & padded, const at::Tensor & nt_example);
21
+ TORCH_API at::Tensor & _nested_from_padded_and_nested_example_outf(const at::Tensor & padded, const at::Tensor & nt_example, at::Tensor & out);
22
+
23
+ } // namespace compositeexplicitautograd
24
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_get_jagged_dummy_ops.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _nested_get_jagged_dummy {
18
+ using schema = at::Tensor (const at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_nested_get_jagged_dummy")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_nested_get_jagged_dummy(Tensor any) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & any);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & any);
26
+ };
27
+
28
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_pack_padded_sequence_backward_native.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor _pack_padded_sequence_backward_symint(const at::Tensor & grad, c10::SymIntArrayRef input_size, const at::Tensor & batch_sizes, bool batch_first);
20
+ } // namespace native
21
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_coo_tensor_with_dims_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _sparse_coo_tensor_with_dims {
18
+ using schema = at::Tensor (int64_t, int64_t, at::IntArrayRef, ::std::optional<at::ScalarType>, ::std::optional<at::Layout>, ::std::optional<at::Device>, ::std::optional<bool>);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_coo_tensor_with_dims")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor")
24
+ static at::Tensor call(int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory);
26
+ };
27
+
28
+ struct TORCH_API _sparse_coo_tensor_with_dims_out {
29
+ using schema = at::Tensor & (int64_t, int64_t, at::IntArrayRef, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_coo_tensor_with_dims")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_coo_tensor_with_dims.out(int sparse_dim, int dense_dim, int[] size, *, Tensor(a!) out) -> Tensor(a!)")
35
+ static at::Tensor & call(int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, at::Tensor & out);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, at::Tensor & out);
37
+ };
38
+
39
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_semi_structured_mm.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <optional>
17
+
18
+
19
+
20
+ #include <ATen/ops/_sparse_semi_structured_mm_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_sparse_semi_structured_mm(Tensor mat1, Tensor mat1_meta, Tensor mat2, *, ScalarType? out_dtype=None) -> Tensor
26
+ inline at::Tensor _sparse_semi_structured_mm(const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, ::std::optional<at::ScalarType> out_dtype=::std::nullopt) {
27
+ return at::_ops::_sparse_semi_structured_mm::call(mat1, mat1_meta, mat2, out_dtype);
28
+ }
29
+
30
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_standard_gamma_cuda_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API at::Tensor _standard_gamma(const at::Tensor & self, ::std::optional<at::Generator> generator=::std::nullopt);
21
+
22
+ } // namespace cuda
23
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_check_tensor_ops.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _test_check_tensor {
18
+ using schema = at::Tensor (const at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_test_check_tensor")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_test_check_tensor(Tensor self) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
26
+ };
27
+
28
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_validate_sparse_bsr_tensor_args_compositeimplicitautograd_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeimplicitautograd {
19
+
20
+ TORCH_API void _validate_sparse_bsr_tensor_args(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size);
21
+
22
+ } // namespace compositeimplicitautograd
23
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_values_copy_compositeexplicitautogradnonfunctional_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautogradnonfunctional {
19
+
20
+ TORCH_API at::Tensor _values_copy(const at::Tensor & self);
21
+
22
+ } // namespace compositeexplicitautogradnonfunctional
23
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/add_ops.h ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API add_Tensor {
18
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::add")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
26
+ };
27
+
28
+ struct TORCH_API add__Tensor {
29
+ using schema = at::Tensor & (at::Tensor &, const at::Tensor &, const at::Scalar &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::add_")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)")
35
+ static at::Tensor & call(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
37
+ };
38
+
39
+ struct TORCH_API add_out {
40
+ using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Scalar &, at::Tensor &);
41
+ using ptr_schema = schema*;
42
+ // See Note [static constexpr char* members for windows NVCC]
43
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::add")
44
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
45
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)")
46
+ static at::Tensor & call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out);
47
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out);
48
+ };
49
+
50
+ struct TORCH_API add_Scalar {
51
+ using schema = at::Tensor (const at::Tensor &, const at::Scalar &, const at::Scalar &);
52
+ using ptr_schema = schema*;
53
+ // See Note [static constexpr char* members for windows NVCC]
54
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::add")
55
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar")
56
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor")
57
+ static at::Tensor call(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
58
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
59
+ };
60
+
61
+ struct TORCH_API add__Scalar {
62
+ using schema = at::Tensor & (at::Tensor &, const at::Scalar &, const at::Scalar &);
63
+ using ptr_schema = schema*;
64
+ // See Note [static constexpr char* members for windows NVCC]
65
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::add_")
66
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar")
67
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)")
68
+ static at::Tensor & call(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
69
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
70
+ };
71
+
72
+ struct TORCH_API add_Scalar_out {
73
+ using schema = at::Tensor & (const at::Tensor &, const at::Scalar &, const at::Scalar &, at::Tensor &);
74
+ using ptr_schema = schema*;
75
+ // See Note [static constexpr char* members for windows NVCC]
76
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::add")
77
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar_out")
78
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)")
79
+ static at::Tensor & call(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out);
80
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out);
81
+ };
82
+
83
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/arccosh_ops.h ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API arccosh {
18
+ using schema = at::Tensor (const at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::arccosh")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "arccosh(Tensor self) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
26
+ };
27
+
28
+ struct TORCH_API arccosh_ {
29
+ using schema = at::Tensor & (at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::arccosh_")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "arccosh_(Tensor(a!) self) -> Tensor(a!)")
35
+ static at::Tensor & call(at::Tensor & self);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self);
37
+ };
38
+
39
+ struct TORCH_API arccosh_out {
40
+ using schema = at::Tensor & (const at::Tensor &, at::Tensor &);
41
+ using ptr_schema = schema*;
42
+ // See Note [static constexpr char* members for windows NVCC]
43
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::arccosh")
44
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
45
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
46
+ static at::Tensor & call(const at::Tensor & self, at::Tensor & out);
47
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out);
48
+ };
49
+
50
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/as_strided_scatter_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API as_strided_scatter {
18
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &, c10::SymIntArrayRef, c10::SymIntArrayRef, ::std::optional<c10::SymInt>);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::as_strided_scatter")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional<c10::SymInt> storage_offset);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional<c10::SymInt> storage_offset);
26
+ };
27
+
28
+ struct TORCH_API as_strided_scatter_out {
29
+ using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, c10::SymIntArrayRef, c10::SymIntArrayRef, ::std::optional<c10::SymInt>, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::as_strided_scatter")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)")
35
+ static at::Tensor & call(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional<c10::SymInt> storage_offset, at::Tensor & out);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional<c10::SymInt> storage_offset, at::Tensor & out);
37
+ };
38
+
39
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_update_stats_native.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &> batch_norm_update_stats_out(const at::Tensor & input, const ::std::optional<at::Tensor> & running_mean, const ::std::optional<at::Tensor> & running_var, double momentum, at::Tensor & out0, at::Tensor & out1);
20
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor> batch_norm_update_stats_cpu(const at::Tensor & input, const ::std::optional<at::Tensor> & running_mean, const ::std::optional<at::Tensor> & running_var, double momentum);
21
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor> batch_norm_update_stats_cuda(const at::Tensor & input, const ::std::optional<at::Tensor> & running_mean, const ::std::optional<at::Tensor> & running_var, double momentum);
22
+ } // namespace native
23
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/binary_cross_entropy_backward_cuda_dispatch.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API at::Tensor binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean);
21
+ TORCH_API at::Tensor & binary_cross_entropy_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean);
22
+ TORCH_API at::Tensor & binary_cross_entropy_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional<at::Tensor> & weight, int64_t reduction, at::Tensor & grad_input);
23
+
24
+ } // namespace cuda
25
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/binomial_cuda_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API at::Tensor binomial(const at::Tensor & count, const at::Tensor & prob, ::std::optional<at::Generator> generator=::std::nullopt);
21
+
22
+ } // namespace cuda
23
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/clamp_native.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+ #include <ATen/ops/clamp_meta.h>
16
+
17
+ namespace at {
18
+ namespace native {
19
+ struct TORCH_API structured_clamp_out : public at::meta::structured_clamp {
20
+ void impl(const at::Tensor & self, at::OptionalScalarRef min, at::OptionalScalarRef max, const at::Tensor & out);
21
+ };
22
+ TORCH_API at::Tensor clamp_quantized_cpu(const at::Tensor & self, const ::std::optional<at::Scalar> & min=::std::nullopt, const ::std::optional<at::Scalar> & max=::std::nullopt);
23
+ struct TORCH_API structured_clamp_Tensor_out : public at::meta::structured_clamp_Tensor {
24
+ void impl(const at::Tensor & self, at::OptionalTensorRef min, at::OptionalTensorRef max, const at::Tensor & out);
25
+ };
26
+ } // namespace native
27
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/conv3d_native.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor conv3d_symint(const at::Tensor & input, const at::Tensor & weight, const ::std::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1);
20
+ TORCH_API at::Tensor conv3d_padding_symint(const at::Tensor & input, const at::Tensor & weight, const ::std::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::string_view padding="valid", c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1);
21
+ } // namespace native
22
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cudnn_convolution_transpose_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor & cudnn_convolution_transpose_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32);
21
+ TORCH_API at::Tensor & cudnn_convolution_transpose_outf(const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out);
22
+ TORCH_API at::Tensor & cudnn_convolution_transpose_symint_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32);
23
+ TORCH_API at::Tensor & cudnn_convolution_transpose_symint_outf(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out);
24
+
25
+ } // namespace compositeexplicitautograd
26
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/dropout_compositeimplicitautograd_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeimplicitautograd {
19
+
20
+ TORCH_API at::Tensor dropout(const at::Tensor & input, double p, bool train);
21
+ TORCH_API at::Tensor & dropout_(at::Tensor & self, double p, bool train);
22
+
23
+ } // namespace compositeimplicitautograd
24
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/elu_backward_meta.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeMetaFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/TensorIterator.h>
13
+ #include <ATen/TensorMeta.h>
14
+ #include <tuple>
15
+ #include <vector>
16
+
17
+ namespace at {
18
+ namespace meta {
19
+
20
+ struct TORCH_API structured_elu_backward : public TensorIteratorBase {
21
+
22
+
23
+ void meta(const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result);
24
+ };
25
+
26
+ } // namespace native
27
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/erf_native.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+ #include <ATen/ops/erf_meta.h>
16
+
17
+ namespace at {
18
+ namespace native {
19
+ struct TORCH_API structured_erf_out : public at::meta::structured_erf {
20
+ void impl(const at::Tensor & self, const at::Tensor & out);
21
+ };
22
+ TORCH_API at::Tensor erf_sparse(const at::Tensor & self);
23
+ TORCH_API at::Tensor & erf_sparse_out(const at::Tensor & self, at::Tensor & out);
24
+ TORCH_API at::Tensor & erf_sparse_(at::Tensor & self);
25
+ TORCH_API at::Tensor erf_sparse_csr(const at::Tensor & self);
26
+ TORCH_API at::Tensor & erf_sparse_csr_out(const at::Tensor & self, at::Tensor & out);
27
+ TORCH_API at::Tensor & erf_sparse_csr_(at::Tensor & self);
28
+ } // namespace native
29
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/expm1.h ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <optional>
17
+
18
+
19
+
20
+ #include <ATen/ops/expm1_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::expm1(Tensor self) -> Tensor
26
+ inline at::Tensor expm1(const at::Tensor & self) {
27
+ return at::_ops::expm1::call(self);
28
+ }
29
+
30
+ // aten::expm1_(Tensor(a!) self) -> Tensor(a!)
31
+ inline at::Tensor & expm1_(at::Tensor & self) {
32
+ return at::_ops::expm1_::call(self);
33
+ }
34
+
35
+ // aten::expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
36
+ inline at::Tensor & expm1_out(at::Tensor & out, const at::Tensor & self) {
37
+ return at::_ops::expm1_out::call(self, out);
38
+ }
39
+ // aten::expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
40
+ inline at::Tensor & expm1_outf(const at::Tensor & self, at::Tensor & out) {
41
+ return at::_ops::expm1_out::call(self, out);
42
+ }
43
+
44
+ }