BryanW commited on
Commit
dcb4c75
·
verified ·
1 Parent(s): 76cbda0

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ArrayRef.h +7 -0
  2. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Backend.h +7 -0
  3. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUApplyUtils.h +356 -0
  4. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUFixedAllocator.h +38 -0
  5. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUGeneratorImpl.h +54 -0
  6. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CUDAFunctions.h +34 -0
  7. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CollapseDims.h +99 -0
  8. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h +30 -0
  9. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Config.h +28 -0
  10. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Device.h +7 -0
  11. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DeviceAccelerator.h +118 -0
  12. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DimVector.h +7 -0
  13. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Dispatch_v2.h +182 -0
  14. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DynamicLibrary.h +41 -0
  15. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/EmptyTensor.h +171 -0
  16. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ExpandUtils.h +540 -0
  17. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/FunctionalTensorWrapper.h +476 -0
  18. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Functions.h +1476 -0
  19. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/InitialTensorOptions.h +20 -0
  20. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h +166 -0
  21. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapMode.h +31 -0
  22. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapTransforms.h +188 -0
  23. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/MethodOperators.h +449 -0
  24. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NamedTensor.h +6 -0
  25. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NativeMetaFunctions.h +1352 -0
  26. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NestedTensorImpl.h +292 -0
  27. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NumericUtils.h +208 -0
  28. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ParallelOpenMP.h +59 -0
  29. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/RedispatchFunctions.h +0 -0
  30. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/RegistrationDeclarations.h +0 -0
  31. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/SDPBackend.h +21 -0
  32. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Scalar.h +8 -0
  33. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/StorageUtils.h +54 -0
  34. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/TensorAccessor.h +7 -0
  35. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h +26 -0
  36. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalState.h +131 -0
  37. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Utils.h +143 -0
  38. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpp_custom_type_hack.h +115 -0
  39. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/THC/THCAtomics.cuh +8 -0
  40. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/THC/THCDeviceUtils.cuh +8 -0
  41. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/ConvUtils.h +195 -0
  42. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/Fbgemm.h +1515 -0
  43. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmBuild.h +116 -0
  44. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmConvert.h +205 -0
  45. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmEmbedding.h +383 -0
  46. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP16.h +60 -0
  47. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP32.h +54 -0
  48. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFPCommon.h +319 -0
  49. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI64.h +36 -0
  50. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8DepthwiseAvx2.h +117 -0
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ArrayRef.h ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <c10/util/ArrayRef.h>
4
+
5
+ #else
6
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
7
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Backend.h ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <c10/core/Backend.h>
4
+
5
+ #else
6
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
7
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUApplyUtils.h ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/CollapseDims.h>
5
+ #include <ATen/Parallel.h>
6
+ #include <ATen/TensorUtils.h>
7
+ #include <c10/util/irange.h>
8
+ #include <cstring>
9
+ #include <limits>
10
+
11
+ namespace at {
12
+
13
+ /*
14
+ * The basic strategy for apply is as follows:
15
+ *
16
+ * 1. Starting with the outermost index, loop until we reach a dimension where
17
+ * the data is no longer contiguous, i.e. the stride at that dimension is not
18
+ * equal to the size of the tensor defined by the outer dimensions. Let's call
19
+ * this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
20
+ * A is equal to the entire Tensor. Let's call the inner tensor B.
21
+ *
22
+ * 2. We loop through the indices in B, starting at its outermost dimension. For
23
+ * example, if B is a 2x2 matrix, then we do:
24
+ *
25
+ * B[0][0]
26
+ * B[0][1]
27
+ * B[1][0]
28
+ * B[1][1]
29
+ *
30
+ * We set the offset into the underlying storage as (storageOffset + stride_B *
31
+ * index_B), i.e. basically we compute the offset into the storage as we would
32
+ * normally for a Tensor. But because we are guaranteed the subsequent data is
33
+ * contiguous in memory, we can simply loop for sizeof(A) iterations and perform
34
+ * the operation, without having to follow the order described by the strides of
35
+ * A.
36
+ *
37
+ * 3. As an optimization, we merge dimensions of A that are contiguous in
38
+ * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
39
+ * then the first two dimensions can be merged for the purposes of APPLY,
40
+ * reducing the number of nested loops.
41
+ */
42
+
43
+ inline Tensor sort_strides(Tensor& tensor_) {
44
+ IntArrayRef strides = tensor_.strides();
45
+ std::vector<int64_t> indices;
46
+ indices.reserve(tensor_.ndimension());
47
+ for (const auto i : c10::irange(tensor_.ndimension())) {
48
+ indices.push_back(i);
49
+ }
50
+ std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
51
+ return strides[i1] > strides[i2];
52
+ });
53
+ Tensor tensor = tensor_.permute(indices);
54
+ return tensor;
55
+ }
56
+
57
+ template <typename T, int N>
58
+ struct strided_tensor_iter_fixed {
59
+ public:
60
+ T* data_ = NULL;
61
+ int64_t dim_ = 0;
62
+
63
+ // NOLINTNEXTLINE(*array*)
64
+ int64_t counter_[N] = {0};
65
+ // NOLINTNEXTLINE(*array*)
66
+ int64_t sizes_[N] = {0};
67
+ // NOLINTNEXTLINE(*array*)
68
+ int64_t strides_[N] = {0};
69
+
70
+ strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
71
+ strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed const& x) =
72
+ delete;
73
+ strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) noexcept = default;
74
+ strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed&& x) noexcept =
75
+ default;
76
+ ~strided_tensor_iter_fixed() noexcept = default;
77
+ strided_tensor_iter_fixed(
78
+ Tensor& tensor,
79
+ [[maybe_unused]] bool sort_strides = false)
80
+ : data_(tensor.data_ptr<T>()) {
81
+ std::memset(counter_, 0, sizeof(int64_t) * N);
82
+ if (tensor.dim() > 0) {
83
+ std::memcpy(
84
+ sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
85
+ std::memcpy(
86
+ strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
87
+ }
88
+ dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
89
+ }
90
+ };
91
+
92
+ template <typename T>
93
+ struct strided_tensor_iter {
94
+ private:
95
+ public:
96
+ T* data_ = NULL;
97
+ int64_t dim_;
98
+
99
+ std::vector<int64_t> counter_;
100
+ std::vector<int64_t> sizes_;
101
+ std::vector<int64_t> strides_;
102
+
103
+ strided_tensor_iter(strided_tensor_iter const&) = delete;
104
+ strided_tensor_iter& operator=(strided_tensor_iter const& x) = delete;
105
+ strided_tensor_iter(strided_tensor_iter&&) noexcept = default;
106
+ strided_tensor_iter& operator=(strided_tensor_iter&&) noexcept = default;
107
+ ~strided_tensor_iter() noexcept = default;
108
+ strided_tensor_iter(Tensor& tensor)
109
+ : data_(tensor.data_ptr<T>()),
110
+ dim_(tensor.ndimension()),
111
+ counter_(dim_, 0),
112
+ sizes_(tensor.sizes().vec()),
113
+ strides_(tensor.strides().vec()) {
114
+ dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
115
+ }
116
+ };
117
+
118
+ inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
119
+ if (tensors.empty())
120
+ return true;
121
+ int64_t all_numel = tensors[0].numel();
122
+ for (const auto i : c10::irange(1, tensors.size())) {
123
+ if (tensors[i].numel() != all_numel)
124
+ return false;
125
+ }
126
+ return true;
127
+ }
128
+
129
+ inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
130
+ std::ostringstream oss;
131
+ oss << "inconsistent tensor size, expected ";
132
+ for (size_t i = 0; i < tensors.size() - 1; i++) {
133
+ oss << tensors[i].sizes() << ", ";
134
+ }
135
+ oss << "and " << tensors[tensors.size() - 1].sizes()
136
+ << " to have the same number of elements, but got ";
137
+ for (size_t i = 0; i < tensors.size() - 1; i++) {
138
+ oss << tensors[i].numel() << ", ";
139
+ }
140
+ oss << "and " << tensors[tensors.size() - 1].numel()
141
+ << " elements respectively";
142
+ return oss.str();
143
+ }
144
+
145
+ inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
146
+ checkDeviceType("CPU_tensor_apply", tensors, kCPU);
147
+ checkLayout("CPU_tensor_apply", tensors, kStrided);
148
+ TORCH_CHECK(_all_equal_numel(tensors), _all_equal_numel_error(tensors));
149
+ // An empty tensor has no elements
150
+ for (auto& t : tensors)
151
+ if (t.numel() == 0)
152
+ return false;
153
+ return true;
154
+ }
155
+
156
+ inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
157
+ int64_t dim = 0;
158
+ for (auto& t : tensors)
159
+ dim = std::max(dim, t.ndimension());
160
+ return dim;
161
+ }
162
+
163
+ inline void iterate(int64_t /*size*/) {}
164
+
165
+ template <typename Arg, typename... Args>
166
+ inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
167
+ iter.counter_[iter.dim_ - 1] += size;
168
+ iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
169
+ iterate(size, iter_tail...);
170
+ }
171
+
172
+ inline bool iterate_continue() {
173
+ return true;
174
+ }
175
+
176
+ template <typename Arg, typename... Args>
177
+ inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
178
+ return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
179
+ iterate_continue(iter_tail...);
180
+ }
181
+
182
+ inline int64_t max_iterate_size() {
183
+ return std::numeric_limits<int64_t>::max();
184
+ }
185
+
186
+ template <typename Arg, typename... Args>
187
+ inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
188
+ return std::min(
189
+ (iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
190
+ max_iterate_size(iter_tail...));
191
+ }
192
+
193
+ inline void iterate_overflow() {}
194
+
195
+ template <typename Arg, typename... Args>
196
+ inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
197
+ if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
198
+ for (int64_t i = iter.dim_ - 1; i > 0; i--) {
199
+ if (iter.counter_[i] == iter.sizes_[i]) {
200
+ iter.counter_[i] = 0;
201
+ iter.counter_[i - 1]++;
202
+ iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
203
+ iter.strides_[i - 1];
204
+ }
205
+ }
206
+ }
207
+ iterate_overflow(iter_tail...);
208
+ }
209
+
210
+ inline void forward(int64_t /*offset*/) {}
211
+
212
+ template <typename Arg, typename... Args>
213
+ inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
214
+ int64_t multi = offset;
215
+ for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
216
+ int64_t inc = multi % iter.sizes_[i];
217
+ multi = multi / iter.sizes_[i];
218
+ iter.data_ = iter.data_ + inc * iter.strides_[i];
219
+ iter.counter_[i] += inc;
220
+ }
221
+ forward(offset, iter_tail...);
222
+ }
223
+
224
+ inline int64_t max_dim() {
225
+ return 0;
226
+ }
227
+
228
+ template <typename Arg, typename... Args>
229
+ inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
230
+ return std::max(iter.dim_, max_dim(iter_tail...));
231
+ }
232
+
233
+ inline void apply_op() {}
234
+
235
+ template <typename Op, typename... Args>
236
+ inline void apply_op(
237
+ int64_t numel,
238
+ int64_t offset,
239
+ const Op& op,
240
+ Args... iters) {
241
+ // For 0-dim tensors
242
+ if (numel == 1 && max_dim(iters...) == 0) {
243
+ op(*iters.data_...);
244
+ return;
245
+ }
246
+ if (offset > 0)
247
+ forward(offset, iters...);
248
+ // Splitting this into chunks helps the compiler create faster assembly
249
+ for (int64_t i = 0; i < numel;) {
250
+ for (; iterate_continue(iters...) && i < numel;) {
251
+ op(*iters.data_...);
252
+ iterate(1, iters...);
253
+ i++;
254
+ }
255
+ iterate_overflow(iters...);
256
+ }
257
+ }
258
+
259
+ /*
260
+ Apply a pointwise operator to sequence of tensors
261
+
262
+ The calling convention for op is a function/functor that takes the same
263
+ number of pointers of type scalar as the number of given tensors. For example,
264
+ to compute a = b * c, op would be of the form:
265
+ [](scalar* a_val, const scalar* b_val, const scalar* c_val) { a_val[0] =
266
+ b_val[0] * c_val[0]; };
267
+ */
268
+
269
+ template <typename scalar1, typename scalar2, typename Op>
270
+ inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
271
+ if (!_apply_preamble({tensor1, tensor2}))
272
+ return;
273
+ if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
274
+ apply_op(
275
+ tensor1.numel(),
276
+ 0,
277
+ op,
278
+ strided_tensor_iter_fixed<scalar1, 8>(tensor1),
279
+ strided_tensor_iter_fixed<scalar2, 8>(tensor2));
280
+ } else {
281
+ apply_op(
282
+ tensor1.numel(),
283
+ 0,
284
+ op,
285
+ strided_tensor_iter<scalar1>(tensor1),
286
+ strided_tensor_iter<scalar2>(tensor2));
287
+ }
288
+ }
289
+
290
+ template <typename scalar1, typename scalar2, typename scalar3, typename Op>
291
+ inline void CPU_tensor_apply3(
292
+ Tensor tensor1,
293
+ Tensor tensor2,
294
+ Tensor tensor3,
295
+ const Op op) {
296
+ if (!_apply_preamble({tensor1, tensor2, tensor3}))
297
+ return;
298
+ if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
299
+ apply_op(
300
+ tensor1.numel(),
301
+ 0,
302
+ op,
303
+ strided_tensor_iter_fixed<scalar1, 8>(tensor1),
304
+ strided_tensor_iter_fixed<scalar2, 8>(tensor2),
305
+ strided_tensor_iter_fixed<scalar3, 8>(tensor3));
306
+ } else {
307
+ apply_op(
308
+ tensor1.numel(),
309
+ 0,
310
+ op,
311
+ strided_tensor_iter<scalar1>(tensor1),
312
+ strided_tensor_iter<scalar2>(tensor2),
313
+ strided_tensor_iter<scalar3>(tensor3));
314
+ }
315
+ }
316
+
317
+ template <
318
+ typename scalar1,
319
+ typename scalar2,
320
+ typename scalar3,
321
+ typename scalar4,
322
+ typename Op>
323
+ inline void CPU_tensor_apply4(
324
+ Tensor tensor1,
325
+ Tensor tensor2,
326
+ Tensor tensor3,
327
+ Tensor tensor4,
328
+ const Op op) {
329
+ if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
330
+ return;
331
+ if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
332
+ apply_op(
333
+ tensor1.numel(),
334
+ 0,
335
+ op,
336
+ strided_tensor_iter_fixed<scalar1, 8>(tensor1),
337
+ strided_tensor_iter_fixed<scalar2, 8>(tensor2),
338
+ strided_tensor_iter_fixed<scalar3, 8>(tensor3),
339
+ strided_tensor_iter_fixed<scalar4, 8>(tensor4));
340
+ } else {
341
+ apply_op(
342
+ tensor1.numel(),
343
+ 0,
344
+ op,
345
+ strided_tensor_iter<scalar1>(tensor1),
346
+ strided_tensor_iter<scalar2>(tensor2),
347
+ strided_tensor_iter<scalar3>(tensor3),
348
+ strided_tensor_iter<scalar4>(tensor4));
349
+ }
350
+ }
351
+
352
+ } // namespace at
353
+
354
+ #else
355
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
356
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUFixedAllocator.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/Allocator.h>
5
+ #include <c10/util/Exception.h>
6
+
7
+ // This file creates a fake allocator that just throws exceptions if
8
+ // it is actually used.
9
+
10
+ // state passed to the allocator is the std::function<void(void*)> called
11
+ // when the blob is release by ATen
12
+
13
+ namespace at {
14
+
15
+ static void* cpu_fixed_malloc(void*, ptrdiff_t) {
16
+ TORCH_CHECK(false, "attempting to resize a tensor view of an external blob");
17
+ }
18
+
19
+ static void* cpu_fixed_realloc(void*, void*, ptrdiff_t) {
20
+ TORCH_CHECK(false, "attempting to resize a tensor view of an external blob");
21
+ }
22
+
23
+ static void cpu_fixed_free(void* state, void* allocation) {
24
+ auto on_release = static_cast<std::function<void(void*)>*>(state);
25
+ (*on_release)(allocation);
26
+ delete on_release;
27
+ }
28
+
29
+ static Allocator CPU_fixed_allocator = {
30
+ cpu_fixed_malloc,
31
+ cpu_fixed_realloc,
32
+ cpu_fixed_free};
33
+
34
+ } // namespace at
35
+
36
+ #else
37
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
38
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUGeneratorImpl.h ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/Generator.h>
5
+ #include <ATen/core/MT19937RNGEngine.h>
6
+ #include <c10/core/GeneratorImpl.h>
7
+ #include <optional>
8
+
9
+ namespace at {
10
+
11
+ struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
12
+ // Constructors
13
+ CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
14
+ ~CPUGeneratorImpl() override = default;
15
+
16
+ // CPUGeneratorImpl methods
17
+ std::shared_ptr<CPUGeneratorImpl> clone() const;
18
+ void set_current_seed(uint64_t seed) override;
19
+ void set_offset(uint64_t offset) override;
20
+ uint64_t get_offset() const override;
21
+ uint64_t current_seed() const override;
22
+ uint64_t seed() override;
23
+ void set_state(const c10::TensorImpl& new_state) override;
24
+ c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
25
+ static c10::DeviceType device_type();
26
+ uint32_t random();
27
+ uint64_t random64();
28
+ std::optional<float> next_float_normal_sample();
29
+ std::optional<double> next_double_normal_sample();
30
+ void set_next_float_normal_sample(std::optional<float> randn);
31
+ void set_next_double_normal_sample(std::optional<double> randn);
32
+ at::mt19937 engine();
33
+ void set_engine(at::mt19937 engine);
34
+
35
+ private:
36
+ CPUGeneratorImpl* clone_impl() const override;
37
+ at::mt19937 engine_;
38
+ std::optional<float> next_float_normal_sample_;
39
+ std::optional<double> next_double_normal_sample_;
40
+ };
41
+
42
+ namespace detail {
43
+
44
+ TORCH_API const Generator& getDefaultCPUGenerator();
45
+ TORCH_API Generator
46
+ createCPUGenerator(uint64_t seed_val = default_rng_seed_val);
47
+
48
+ } // namespace detail
49
+
50
+ } // namespace at
51
+
52
+ #else
53
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
54
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CUDAFunctions.h ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #include <ATen/core/TensorBody.h>
3
+
4
+ // TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
5
+ // Code introduced to avoid cyclic dependency in static dispatch is no longer
6
+ // needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
7
+ // to Operators.cpp for supporting multiple backends with multiple kernels.
8
+ //
9
+ // Note [Avoiding Include Cycles In Static Dispatch]
10
+ // In order to avoid #include cycles in the static dispatch build, we've carefully split out
11
+ // the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
12
+ //
13
+ // Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
14
+ // - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
15
+ // all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
16
+ // directly inlined into TensorBody.h.
17
+ // - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
18
+ // which include functions that have defaultable std::optional<Tensor> arguments.
19
+ // That requires knowing the full Tensor class definition.
20
+ //
21
+ // We break the cycle by doing the following:
22
+ // - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
23
+ // - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
24
+ // - CPUFunctions_inl.h includes everything else
25
+ // - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
26
+ // and then it includes CPUFunctions_inl.h.
27
+ // - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
28
+ // - This also means that static dispatch build, CPUFunctions.h only needs to
29
+ // #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
30
+ #include <ATen/CUDAFunctions_inl.h>
31
+
32
+ #else
33
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
34
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CollapseDims.h ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #include <c10/util/Exception.h>
3
+ #include <utility>
4
+
5
+ namespace at {
6
+
7
+ /*
8
+ [collapse dims] Updates sizes, and strides to reflect a "collapse" of
9
+ the info, possibly excluding the optional excludeDim. A "collapsed" version
10
+ of the info is the fewest dims that order the tensor's elements in the same
11
+ way as the original info. If excludeDim is specified, the collapse is the
12
+ fewest dims that order the tensor's elements as the original and preserve the
13
+ excluded dimension, unless the tensor collapses to a point.
14
+
15
+ This function returns a pair of values.
16
+
17
+ 1) The (new) index of the preserved dimension if excludeDim is
18
+ specified. 0 if the tensor is collapsed to a point. -1
19
+ otherwise.
20
+
21
+ 2) The new number of dimensions.
22
+ */
23
+ template <typename T>
24
+ inline std::pair<int64_t, int64_t> collapse_dims(
25
+ T* sizes,
26
+ T* strides,
27
+ int64_t dims,
28
+ const int excludeDim = -1) {
29
+ TORCH_CHECK(
30
+ excludeDim >= -1 && excludeDim < dims,
31
+ "expected excluded dim between -1 and dims - 1");
32
+
33
+ int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
34
+ int64_t newIndex = -1;
35
+ int64_t oldIndex = 0;
36
+ int64_t remappedExcludedDim = -1;
37
+
38
+ while (oldIndex < dims) {
39
+ // Finds a dimension to collapse into
40
+ for (; oldIndex < stopDim; ++oldIndex) {
41
+ if (sizes[oldIndex] == 1) {
42
+ continue;
43
+ }
44
+
45
+ ++newIndex;
46
+ sizes[newIndex] = sizes[oldIndex];
47
+ strides[newIndex] = strides[oldIndex];
48
+ ++oldIndex;
49
+ break;
50
+ }
51
+
52
+ // Collapses dims
53
+ for (; oldIndex < stopDim; ++oldIndex) {
54
+ if (sizes[oldIndex] == 1) {
55
+ continue;
56
+ }
57
+
58
+ if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
59
+ sizes[newIndex] *= sizes[oldIndex];
60
+ strides[newIndex] = strides[oldIndex];
61
+ } else {
62
+ ++newIndex;
63
+ sizes[newIndex] = sizes[oldIndex];
64
+ strides[newIndex] = strides[oldIndex];
65
+ }
66
+ }
67
+
68
+ // Handles excludeDim being set (oldIndex == excludeDim)
69
+ if (oldIndex != dims) {
70
+ // Preserves excluded dimension
71
+ ++newIndex;
72
+ sizes[newIndex] = sizes[oldIndex];
73
+ strides[newIndex] = strides[oldIndex];
74
+ remappedExcludedDim = newIndex;
75
+
76
+ // Restarts iteration after excludeDim
77
+ ++oldIndex;
78
+ stopDim = dims;
79
+ }
80
+ }
81
+
82
+ // Handles special case of all dims size 1
83
+ if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
84
+ dims = 1;
85
+ sizes[0] = 1;
86
+ strides[0] = 1;
87
+
88
+ return std::pair<int64_t, int64_t>(0, 1);
89
+ }
90
+
91
+ dims = newIndex + 1;
92
+ return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
93
+ }
94
+
95
+ } // namespace at
96
+
97
+ #else
98
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
99
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ // @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
4
+
5
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
6
+
7
+ // The only #includes we need are for custom classes that have defaults in the C++ API
8
+ #include <c10/core/MemoryFormat.h>
9
+ #include <c10/core/Scalar.h>
10
+ #include <ATen/core/Reduction.h>
11
+
12
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
13
+ #error This change adds a dependency on all pytorch operators, meaning the \
14
+ file will need to be re-compiled every time an operator is changed or added. \
15
+ Consider including a specific operator from \
16
+ <ATen/ops/{my_operator}_compositeimplicitautogradnestedtensor_dispatch.h>. \
17
+ See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
18
+ #endif
19
+
20
+ #include <ATen/ops/randn_like_compositeimplicitautogradnestedtensor_dispatch.h>
21
+ #include <ATen/ops/reshape_compositeimplicitautogradnestedtensor_dispatch.h>
22
+ #include <ATen/ops/reshape_as_compositeimplicitautogradnestedtensor_dispatch.h>
23
+ #include <ATen/ops/zeros_like_compositeimplicitautogradnestedtensor_dispatch.h>
24
+
25
+
26
+
27
+
28
+ #else
29
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
30
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Config.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // Test these using #if AT_MKL_ENABLED(), not #ifdef, so that it's
5
+ // obvious if you forgot to include Config.h
6
+ // c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
7
+ //
8
+ // DO NOT put the macros for CUDA libraries in this file; they belong in cuda/CUDAConfig.h
9
+
10
+ #define AT_MKLDNN_ENABLED() 1
11
+ #define AT_MKLDNN_ACL_ENABLED() 0
12
+ #define AT_MKL_ENABLED() 1
13
+ #define AT_MKL_SEQUENTIAL() 0
14
+ #define AT_POCKETFFT_ENABLED() 0
15
+ #define AT_NNPACK_ENABLED() 1
16
+ #define CAFFE2_STATIC_LINK_CUDA() 0
17
+ #define AT_BUILD_WITH_BLAS() 1
18
+ #define AT_BUILD_WITH_LAPACK() 1
19
+ #define AT_PARALLEL_OPENMP 1
20
+ #define AT_PARALLEL_NATIVE 0
21
+ #define AT_BLAS_F2C() 0
22
+ #define AT_BLAS_USE_CBLAS_DOT() 0
23
+ #define AT_KLEIDIAI_ENABLED() 0
24
+ #define AT_USE_EIGEN_SPARSE() 0
25
+
26
+ #else
27
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
28
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Device.h ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <c10/core/Device.h>
4
+
5
+ #else
6
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
7
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DeviceAccelerator.h ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/CachingDeviceAllocator.h>
5
+ #include <c10/core/DeviceCapability.h>
6
+ #include <c10/core/DeviceType.h>
7
+ #include <c10/macros/Macros.h>
8
+
9
+ #include <ATen/detail/MTIAHooksInterface.h>
10
+ #include <optional>
11
+
12
+ namespace at::accelerator {
13
+
14
+ // Note [Accelerator Concept]
15
+ // This file defines the top level Accelerator concept for PyTorch.
16
+ // A device is an accelerator per the definition here if:
17
+ // - It is mutually exclusive with all other accelerators
18
+ // - It performs asynchronous compute via a Stream/Event system
19
+ // - It provides a set of common APIs as defined by AcceleratorHooksInterface
20
+ //
21
+ // As of today, accelerator devices are (in no particular order):
22
+ // CUDA, MTIA, XPU, HIP, MPS, PrivateUse1
23
+
24
+ // Ensures that only one accelerator is available (at
25
+ // compile time if possible) and return it.
26
+ // When checked is true, the returned optional always has a value.
27
+ TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false);
28
+
29
+ // Check if the given device type is an accelerator.
30
+ TORCH_API bool isAccelerator(c10::DeviceType device_type);
31
+
32
+ // Check if the given device type is an accelerator, not the excluded ones.
33
+ template <
34
+ typename... T,
35
+ typename = std::enable_if_t<(std::is_same_v<T, c10::DeviceType> && ...)>>
36
+ inline bool isAcceleratorExcluded(
37
+ c10::DeviceType device_type,
38
+ c10::DeviceType first_excluded,
39
+ T... rest_excluded) {
40
+ if constexpr (sizeof...(rest_excluded) > 0) {
41
+ return device_type != first_excluded &&
42
+ isAcceleratorExcluded(device_type, rest_excluded...);
43
+ } else {
44
+ return device_type != first_excluded && isAccelerator(device_type);
45
+ }
46
+ }
47
+
48
+ // Return the number of the device available. Note that this is *REQUIRED* to
49
+ // not raise any exception.
50
+ TORCH_API c10::DeviceIndex deviceCount();
51
+
52
+ // Set the current device index to the given device index.
53
+ TORCH_API void setDeviceIndex(c10::DeviceIndex device_index);
54
+
55
+ // Get the current device index.
56
+ TORCH_API c10::DeviceIndex getDeviceIndex();
57
+
58
+ // Set the current stream to a given stream. Note that this API doesn't change
59
+ // the current device index.
60
+ TORCH_API void setCurrentStream(c10::Stream stream);
61
+
62
+ // Get the current stream of the given device index.
63
+ TORCH_API c10::Stream getCurrentStream(c10::DeviceIndex device_index);
64
+
65
+ // Wait (by blocking the calling thread) until all the work previously enqueued
66
+ // on the given device index has been completed.
67
+ TORCH_API void synchronizeDevice(c10::DeviceIndex device_index);
68
+
69
+ // Set the current device index to the given device_index and return the
70
+ // original device index that was active before the change.
71
+ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index);
72
+
73
+ // Set the current device index to the given device_index. Avoid creating a new
74
+ // context if the context for device_index is not initialized. Return the
75
+ // original device index that was active before the change.
76
+ TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index);
77
+
78
+ // Get the device capability of the given device index.
79
+ TORCH_API c10::DeviceCapability getDeviceCapability(
80
+ c10::DeviceIndex device_index);
81
+
82
+ TORCH_API inline void emptyCache() {
83
+ const auto device_type = getAccelerator(true).value();
84
+ at::getDeviceAllocator(device_type)->emptyCache();
85
+ }
86
+
87
+ TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats(
88
+ c10::DeviceIndex device_index) {
89
+ const auto device_type = getAccelerator(true).value();
90
+ return at::getDeviceAllocator(device_type)->getDeviceStats(device_index);
91
+ }
92
+
93
+ TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) {
94
+ const auto device_type = getAccelerator(true).value();
95
+ at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index);
96
+ }
97
+
98
+ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
99
+ const auto device_type = getAccelerator(true).value();
100
+ at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
101
+ }
102
+
103
+ TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
104
+ c10::DeviceIndex device_index) {
105
+ const auto device_type = getAccelerator(true).value();
106
+ return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
107
+ }
108
+ } // namespace at::accelerator
109
+
110
+ namespace at {
111
+ // Keep BC only
112
+ using at::accelerator::getAccelerator;
113
+ using at::accelerator::isAccelerator;
114
+ } // namespace at
115
+
116
+ #else
117
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
118
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DimVector.h ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/core/DimVector.h>
4
+
5
+ #else
6
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
7
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Dispatch_v2.h ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <torch/headeronly/core/Dispatch_v2.h>
5
+
6
+ // Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
7
+ #include <ATen/Dispatch.h>
8
+
9
+ // This is a new implementation of the AT_DISPATCH macro family from
10
+ // ATen/Dispatch.h
11
+ //
12
+ // The intended usage is:
13
+ //
14
+ // ScalarType scalar_type;
15
+ //
16
+ // AT_DISPATCH_V2(
17
+ // scalar_type,
18
+ // "debug string",
19
+ // AT_WRAP([&] {
20
+ // ... code to specialize with scalar_t ...
21
+ // }),
22
+ // kHalf,
23
+ // AT_EXPAND(AT_ALL_TYPES),
24
+ // ... as many types arguments as needed ...
25
+ // )
26
+ //
27
+ // For example, given an old style:
28
+ //
29
+ // AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
30
+ // kComplexHalf,
31
+ // kHalf,
32
+ // self.scalar_type(),
33
+ // "_local_scalar_dense_cpu",
34
+ // [&] {
35
+ // scalar_t value = *self.data_ptr<scalar_t>();
36
+ // r = Scalar(value);
37
+ // }
38
+ // )
39
+ //
40
+ // You now write:
41
+ //
42
+ // AT_DISPATCH_V2(
43
+ // self.scalar_type(),
44
+ // "_local_scalar_dense_cpu",
45
+ // AT_WRAP([&] {
46
+ // scalar_t value = *self.data_ptr<scalar_t>();
47
+ // r = Scalar(value);
48
+ // }),
49
+ // AT_EXPAND(AT_ALL_TYPES),
50
+ // AT_EXPAND(AT_COMPLEX_TYPES),
51
+ // kComplexHalf,
52
+ // kHalf,
53
+ // )
54
+ //
55
+ // Notably, it sports the following improvements:
56
+ //
57
+ // - It is not necessary to specify the arity (e.g.,
58
+ // AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3,4,...})
59
+ // when using the macro
60
+ //
61
+ // - It is not necessary to specify each dtype individually; if
62
+ // there is a set of related dtypes and you want to dispatch
63
+ // over all of them, you can simply say, e.g., AT_EXPAND(AT_INTEGRAL_TYPES)
64
+ // in your argument list.
65
+ //
66
+ // However, you must remember to wrap the payload body in AT_WRAP, or commas
67
+ // inside your lambda will be improperly handled. Furthermore, if you more
68
+ // entries to ScalarType than can be supported by this macro, it will fail
69
+ // with an obscure error (due to attempting to concatenate AT_AP with
70
+ // something that is not a number).
71
+ //
72
+ // The implementation strategy is to use the count arguments trick
73
+ // (e.g., as described in https://stackoverflow.com/a/2124385/23845)
74
+ // to discover how many dtypes have been passed, and then dispatch to a
75
+ // hand-written macro for each arity that applies as many DISPATCH_CASE as
76
+ // necessary. The hand-written macros can be regenerated for other arities
77
+ // with the script below.
78
+ //
79
+ // There is some delicacy in the implementation in controlling when
80
+ // macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
81
+ // relied on GPT4 to help me get it right.
82
+
83
+ // See documentation above
84
+ #define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
85
+ THO_DISPATCH_V2_TMPL( \
86
+ AT_DISPATCH_SWITCH, \
87
+ AT_DISPATCH_CASE, \
88
+ TYPE, \
89
+ NAME, \
90
+ AT_WRAP(BODY), \
91
+ __VA_ARGS__)
92
+
93
+ // Unused helper macros, kept for BC:
94
+ #define AT_AP_VAR(N, T, ...) \
95
+ AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
96
+
97
+ // Ensure we never have too many scalar types for the expansion here to
98
+ // support. To bump this, you must regenerate the macros below.
99
+ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);
100
+
101
+ // Python code to regenerate generate code below:
102
+ #if 0
103
+
104
+ num_args = 60
105
+
106
+ for i in range(1, num_args+1):
107
+ args = ', '.join(f'_{i}' for i in range(1, i+1))
108
+ cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
109
+ print(f'#define AT_AP{i}(N, {args}) {cases}')
110
+
111
+ #endif
112
+
113
+ // Begin generated code
114
+ // clang-format off
115
+
116
+ #define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
117
+ #define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
118
+ #define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
119
+ #define AT_AP4(N, _1, _2, _3, _4) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N)
120
+ #define AT_AP5(N, _1, _2, _3, _4, _5) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N)
121
+ #define AT_AP6(N, _1, _2, _3, _4, _5, _6) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N)
122
+ #define AT_AP7(N, _1, _2, _3, _4, _5, _6, _7) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N)
123
+ #define AT_AP8(N, _1, _2, _3, _4, _5, _6, _7, _8) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N)
124
+ #define AT_AP9(N, _1, _2, _3, _4, _5, _6, _7, _8, _9) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N)
125
+ #define AT_AP10(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N)
126
+ #define AT_AP11(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N)
127
+ #define AT_AP12(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N)
128
+ #define AT_AP13(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N)
129
+ #define AT_AP14(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N)
130
+ #define AT_AP15(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N)
131
+ #define AT_AP16(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N)
132
+ #define AT_AP17(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N)
133
+ #define AT_AP18(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N)
134
+ #define AT_AP19(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N)
135
+ #define AT_AP20(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N)
136
+ #define AT_AP21(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N)
137
+ #define AT_AP22(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N)
138
+ #define AT_AP23(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N)
139
+ #define AT_AP24(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N)
140
+ #define AT_AP25(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N)
141
+ #define AT_AP26(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N)
142
+ #define AT_AP27(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N)
143
+ #define AT_AP28(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N)
144
+ #define AT_AP29(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N)
145
+ #define AT_AP30(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N)
146
+ #define AT_AP31(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N)
147
+ #define AT_AP32(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N)
148
+ #define AT_AP33(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N)
149
+ #define AT_AP34(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N)
150
+ #define AT_AP35(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N)
151
+ #define AT_AP36(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N)
152
+ #define AT_AP37(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N)
153
+ #define AT_AP38(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N)
154
+ #define AT_AP39(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N)
155
+ #define AT_AP40(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N)
156
+ #define AT_AP41(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N)
157
+ #define AT_AP42(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N)
158
+ #define AT_AP43(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N)
159
+ #define AT_AP44(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N)
160
+ #define AT_AP45(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N)
161
+ #define AT_AP46(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N)
162
+ #define AT_AP47(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N)
163
+ #define AT_AP48(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N)
164
+ #define AT_AP49(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N)
165
+ #define AT_AP50(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N)
166
+ #define AT_AP51(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N)
167
+ #define AT_AP52(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N)
168
+ #define AT_AP53(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N)
169
+ #define AT_AP54(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N)
170
+ #define AT_AP55(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N)
171
+ #define AT_AP56(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N)
172
+ #define AT_AP57(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N)
173
+ #define AT_AP58(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N)
174
+ #define AT_AP59(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) AT_DISPATCH_CASE(_59, N)
175
+ #define AT_AP60(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) AT_DISPATCH_CASE(_59, N) AT_DISPATCH_CASE(_60, N)
176
+
177
+ // End generated code
178
+ // clang-format on
179
+
180
+ #else
181
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
182
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DynamicLibrary.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/Utils.h>
5
+ #include <c10/macros/Export.h>
6
+ #include <c10/util/Exception.h>
7
+
8
+ namespace c10 {
9
+
10
+ class DynamicLibraryError : public Error {
11
+ using Error::Error;
12
+ };
13
+
14
+ } // namespace c10
15
+
16
+ namespace at {
17
+
18
+ struct DynamicLibrary {
19
+ AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);
20
+ DynamicLibrary(DynamicLibrary&& other) = delete;
21
+ DynamicLibrary& operator=(DynamicLibrary&&) = delete;
22
+
23
+ TORCH_API DynamicLibrary(
24
+ const char* name,
25
+ const char* alt_name = nullptr,
26
+ bool leak_handle = false);
27
+
28
+ TORCH_API void* sym(const char* name);
29
+
30
+ TORCH_API ~DynamicLibrary();
31
+
32
+ private:
33
+ bool leak_handle;
34
+ void* handle = nullptr;
35
+ };
36
+
37
+ } // namespace at
38
+
39
+ #else
40
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
41
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/EmptyTensor.h ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/core/TensorBase.h>
4
+
5
+ namespace at::detail {
6
+
7
+ inline void check_size_nonnegative(ArrayRef<int64_t> size) {
8
+ for (const auto& x : size) {
9
+ TORCH_CHECK(
10
+ x >= 0,
11
+ "Trying to create tensor with negative dimension ",
12
+ x,
13
+ ": ",
14
+ size);
15
+ }
16
+ }
17
+
18
+ inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
19
+ for (const auto& x : size) {
20
+ TORCH_SYM_CHECK(
21
+ x.sym_ge(0),
22
+ "Trying to create tensor with negative dimension ",
23
+ x,
24
+ ": ",
25
+ size);
26
+ }
27
+ }
28
+
29
+ TORCH_API size_t computeStorageNbytesContiguous(
30
+ IntArrayRef sizes,
31
+ size_t itemsize,
32
+ size_t storage_offset = 0);
33
+ TORCH_API SymInt computeStorageNbytesContiguous(
34
+ SymIntArrayRef sizes,
35
+ const SymInt& itemsize,
36
+ const SymInt& storage_offset = 0);
37
+ TORCH_API size_t computeStorageNbytes(
38
+ IntArrayRef sizes,
39
+ IntArrayRef strides,
40
+ size_t itemsize,
41
+ size_t storage_offset = 0);
42
+ TORCH_API SymInt computeStorageNbytes(
43
+ SymIntArrayRef sizes,
44
+ SymIntArrayRef strides,
45
+ const SymInt& itemsize,
46
+ const SymInt& storage_offset = 0);
47
+
48
+ TORCH_API TensorBase empty_generic(
49
+ IntArrayRef size,
50
+ c10::Allocator* allocator,
51
+ c10::DispatchKeySet ks,
52
+ ScalarType scalar_type,
53
+ std::optional<c10::MemoryFormat> memory_format_opt);
54
+
55
+ TORCH_API TensorBase empty_generic_symint(
56
+ SymIntArrayRef size,
57
+ c10::Allocator* allocator,
58
+ c10::DispatchKeySet ks,
59
+ ScalarType scalar_type,
60
+ std::optional<c10::MemoryFormat> memory_format_opt);
61
+
62
+ TORCH_API TensorBase empty_strided_generic(
63
+ IntArrayRef size,
64
+ IntArrayRef stride,
65
+ c10::Allocator* allocator,
66
+ c10::DispatchKeySet ks,
67
+ ScalarType scalar_type);
68
+
69
+ TORCH_API TensorBase empty_strided_symint_generic(
70
+ SymIntArrayRef size,
71
+ SymIntArrayRef stride,
72
+ c10::Allocator* allocator,
73
+ c10::DispatchKeySet ks,
74
+ ScalarType scalar_type);
75
+
76
+ TORCH_API TensorBase empty_cpu(
77
+ IntArrayRef size,
78
+ ScalarType dtype,
79
+ bool pin_memory = false,
80
+ std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
81
+
82
+ TORCH_API TensorBase empty_cpu(
83
+ IntArrayRef size,
84
+ std::optional<ScalarType> dtype_opt,
85
+ std::optional<Layout> layout_opt,
86
+ std::optional<Device> device_opt,
87
+ std::optional<bool> pin_memory_opt,
88
+ std::optional<c10::MemoryFormat> memory_format_opt);
89
+
90
+ TORCH_API TensorBase empty_cpu(IntArrayRef size, const TensorOptions& options);
91
+
92
+ TORCH_API TensorBase empty_strided_cpu(
93
+ IntArrayRef size,
94
+ IntArrayRef stride,
95
+ ScalarType dtype,
96
+ bool pin_memory = false);
97
+
98
+ TORCH_API TensorBase empty_strided_cpu(
99
+ IntArrayRef size,
100
+ IntArrayRef stride,
101
+ std::optional<ScalarType> dtype_opt,
102
+ std::optional<Layout> layout_opt,
103
+ std::optional<Device> device_opt,
104
+ std::optional<bool> pin_memory_opt);
105
+
106
+ TORCH_API TensorBase empty_strided_cpu(
107
+ IntArrayRef size,
108
+ IntArrayRef stride,
109
+ const TensorOptions& options);
110
+
111
+ TORCH_API TensorBase empty_meta(
112
+ IntArrayRef size,
113
+ ScalarType dtype,
114
+ std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
115
+
116
+ TORCH_API TensorBase empty_meta(
117
+ IntArrayRef size,
118
+ std::optional<ScalarType> dtype_opt,
119
+ std::optional<Layout> layout_opt,
120
+ std::optional<Device> device_opt,
121
+ std::optional<bool> pin_memory_opt,
122
+ std::optional<c10::MemoryFormat> memory_format_opt);
123
+
124
+ TORCH_API TensorBase empty_symint_meta(
125
+ SymIntArrayRef size,
126
+ std::optional<ScalarType> dtype_opt,
127
+ std::optional<Layout> layout_opt,
128
+ std::optional<Device> device_opt,
129
+ std::optional<bool> pin_memory_opt,
130
+ std::optional<c10::MemoryFormat> memory_format_opt);
131
+
132
+ TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options);
133
+
134
+ TORCH_API TensorBase
135
+ empty_strided_meta(IntArrayRef size, IntArrayRef stride, ScalarType dtype);
136
+
137
+ TORCH_API TensorBase empty_strided_meta(
138
+ IntArrayRef size,
139
+ IntArrayRef stride,
140
+ std::optional<ScalarType> dtype_opt,
141
+ std::optional<Layout> layout_opt,
142
+ std::optional<Device> device_opt,
143
+ std::optional<bool> pin_memory_opt);
144
+
145
+ TORCH_API TensorBase empty_strided_meta(
146
+ IntArrayRef size,
147
+ IntArrayRef stride,
148
+ const TensorOptions& options);
149
+
150
+ TORCH_API TensorBase empty_strided_symint_meta(
151
+ SymIntArrayRef size,
152
+ SymIntArrayRef stride,
153
+ ScalarType dtype);
154
+
155
+ TORCH_API TensorBase empty_strided_symint_meta(
156
+ SymIntArrayRef size,
157
+ SymIntArrayRef stride,
158
+ std::optional<ScalarType> dtype_opt,
159
+ std::optional<Layout> layout_opt,
160
+ std::optional<Device> device_opt);
161
+
162
+ TORCH_API TensorBase empty_strided_symint_meta(
163
+ SymIntArrayRef size,
164
+ SymIntArrayRef stride,
165
+ const TensorOptions& options);
166
+
167
+ } // namespace at::detail
168
+
169
+ #else
170
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
171
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ExpandUtils.h ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #ifndef AT_PER_OPERATOR_HEADERS
5
+ #include <ATen/Functions.h>
6
+ #else
7
+ #include <ATen/ops/view.h>
8
+ #include <ATen/ops/view_copy.h>
9
+ #endif
10
+
11
+ #include <ATen/Tensor.h>
12
+ #include <ATen/core/DimVector.h>
13
+ #include <c10/util/Exception.h>
14
+ #include <c10/util/MaybeOwned.h>
15
+ #include <c10/util/irange.h>
16
+
17
+ #include <functional>
18
+ #include <tuple>
19
+ #include <utility>
20
+
21
+ namespace at {
22
+
23
+ TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
24
+ TORCH_API std::vector<SymInt> infer_size_symint(
25
+ SymIntArrayRef a,
26
+ SymIntArrayRef b);
27
+ TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
28
+ TORCH_API SymDimVector
29
+ infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
30
+
31
+ // Named type instead of a pair/tuple so that we can be sure to
32
+ // construct the vectors in place and get NRVO.
33
+ template <typename Container>
34
+ struct InferExpandGeometryResult {
35
+ Container sizes;
36
+ Container strides;
37
+ explicit InferExpandGeometryResult(size_t ndim)
38
+ : sizes(ndim), strides(ndim) {}
39
+ explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
40
+ : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
41
+ };
42
+
43
+ TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
44
+ inferExpandGeometry(
45
+ IntArrayRef tensor_sizes,
46
+ IntArrayRef tensor_strides,
47
+ IntArrayRef sizes);
48
+
49
+ TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
50
+ IntArrayRef tensor_sizes,
51
+ IntArrayRef tensor_strides,
52
+ IntArrayRef sizes);
53
+
54
+ TORCH_API std::vector<int64_t> infer_dense_strides(
55
+ IntArrayRef tensor_sizes,
56
+ IntArrayRef tensor_strides);
57
+
58
+ // True if input shapes are expandable
59
+ // NOTE: infer_size did a similar check, please keep them sync if change is
60
+ // needed
61
+ inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
62
+ size_t ndim1 = shape1.size();
63
+ size_t ndim2 = shape2.size();
64
+ size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2;
65
+
66
+ for (int64_t i = static_cast<int64_t>(ndim) - 1; i >= 0; --i) {
67
+ if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 ||
68
+ shape2[ndim2] == 1) {
69
+ continue;
70
+ }
71
+ return false;
72
+ }
73
+ return true;
74
+ }
75
+
76
+ // avoid copy-construction of Tensor by using a reference_wrapper.
77
+ inline void check_defined(
78
+ std::initializer_list<std::reference_wrapper<const Tensor>> tensors,
79
+ const char* api_name) {
80
+ for (auto& t : tensors) {
81
+ if (!t.get().defined()) {
82
+ TORCH_CHECK(false, api_name, "(...) called with an undefined Tensor");
83
+ }
84
+ }
85
+ }
86
+
87
+ // NOTE [ ExpandUtils Borrowing ]
88
+ //
89
+ // Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because
90
+ // expansion may not actually be needed, in which case we can improve
91
+ // efficiency by returning
92
+ // `c10::MaybeOwned<Tensor>::borrowed(to_expand)`. However, this means
93
+ // that you need to be careful: the returned `c10::MaybeOwned<Tensor>`
94
+ // must not outlive the original `Tensor` object that `to_expand`
95
+ // referred to! The deleted rvalue reference overloads of these
96
+ // functions help with this by preventing trivial use of a temporary
97
+ // resulting from a function call, but it is still possible to make a
98
+ // mistake.
99
+
100
+ inline c10::MaybeOwned<Tensor> expand_inplace(
101
+ const Tensor& tensor,
102
+ const Tensor& to_expand) {
103
+ if (tensor.sym_sizes().equals(to_expand.sym_sizes())) {
104
+ return c10::MaybeOwned<Tensor>::borrowed(to_expand);
105
+ }
106
+ return c10::MaybeOwned<Tensor>::owned(
107
+ to_expand.expand_symint(tensor.sym_sizes()));
108
+ }
109
+
110
+ inline c10::MaybeOwned<Tensor> expand_inplace(
111
+ const Tensor& tensor,
112
+ Tensor&& to_expand) = delete;
113
+
114
+ inline c10::MaybeOwned<Tensor> expand_inplace(
115
+ const Tensor& tensor,
116
+ const Tensor& to_expand,
117
+ const char* api_name) {
118
+ check_defined({tensor, to_expand}, api_name);
119
+ return expand_inplace(tensor, to_expand);
120
+ }
121
+
122
+ inline c10::MaybeOwned<Tensor> expand_inplace(
123
+ const Tensor& tensor,
124
+ Tensor&& to_expand,
125
+ const char* api_name) = delete;
126
+
127
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
128
+ expand_inplace(
129
+ const Tensor& tensor,
130
+ const Tensor& to_expand1,
131
+ const Tensor& to_expand2) {
132
+ if (tensor.sizes().equals(to_expand1.sizes()) &&
133
+ tensor.sizes().equals((to_expand2.sizes()))) {
134
+ return std::make_tuple(
135
+ c10::MaybeOwned<Tensor>::borrowed(to_expand1),
136
+ c10::MaybeOwned<Tensor>::borrowed(to_expand2));
137
+ }
138
+
139
+ return std::make_tuple(
140
+ c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())),
141
+ c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes())));
142
+ }
143
+
144
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
145
+ expand_inplace(
146
+ const Tensor& tensor,
147
+ Tensor&& to_expand1,
148
+ const Tensor& to_expand2) = delete;
149
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
150
+ expand_inplace(
151
+ const Tensor& tensor,
152
+ const Tensor& to_expand1,
153
+ Tensor&& to_expand2) = delete;
154
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
155
+ expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
156
+ delete;
157
+
158
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
159
+ expand_inplace(
160
+ const Tensor& tensor,
161
+ const Tensor& to_expand1,
162
+ const Tensor& to_expand2,
163
+ const char* api_name) {
164
+ check_defined({tensor, to_expand1, to_expand2}, api_name);
165
+ return expand_inplace(tensor, to_expand1, to_expand2);
166
+ }
167
+
168
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
169
+ expand_inplace(
170
+ const Tensor& tensor,
171
+ Tensor&& to_expand1,
172
+ const Tensor& to_expand2,
173
+ const char* api_name) = delete;
174
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
175
+ expand_inplace(
176
+ const Tensor& tensor,
177
+ const Tensor& to_expand1,
178
+ Tensor&& to_expand2,
179
+ const char* api_name) = delete;
180
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
181
+ expand_inplace(
182
+ const Tensor& tensor,
183
+ Tensor&& to_expand1,
184
+ Tensor&& to_expand2,
185
+ const char* api_name) = delete;
186
+
187
+ // See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
188
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
189
+ expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
190
+ auto s1 = to_expand1.sym_sizes();
191
+ auto s2 = to_expand2.sym_sizes();
192
+ if (s1.equals(s2)) {
193
+ return std::make_tuple(
194
+ c10::MaybeOwned<Tensor>::borrowed(to_expand1),
195
+ c10::MaybeOwned<Tensor>::borrowed(to_expand2));
196
+ }
197
+
198
+ auto expanded_size = infer_size_symdimvector(s1, s2);
199
+ return std::make_tuple(
200
+ c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)),
201
+ c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size)));
202
+ }
203
+
204
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
205
+ expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
206
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
207
+ expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
208
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
209
+ expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
210
+
211
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
212
+ expand_outplace(
213
+ const Tensor& to_expand1,
214
+ const Tensor& to_expand2,
215
+ const char* api_name) {
216
+ check_defined({to_expand1, to_expand2}, api_name);
217
+ return expand_outplace(to_expand1, to_expand2);
218
+ }
219
+
220
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
221
+ expand_outplace(
222
+ Tensor&& to_expand1,
223
+ const Tensor& to_expand2,
224
+ const char* api_name) = delete;
225
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
226
+ expand_outplace(
227
+ const Tensor& to_expand1,
228
+ Tensor&& to_expand2,
229
+ const char* api_name) = delete;
230
+ inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
231
+ expand_outplace(
232
+ Tensor&& to_expand1,
233
+ Tensor&& to_expand2,
234
+ const char* api_name) = delete;
235
+
236
+ inline std::tuple<
237
+ c10::MaybeOwned<Tensor>,
238
+ c10::MaybeOwned<Tensor>,
239
+ c10::MaybeOwned<Tensor>>
240
+ expand_outplace(
241
+ const Tensor& to_expand1,
242
+ const Tensor& to_expand2,
243
+ const Tensor& to_expand3) {
244
+ if (to_expand1.sizes().equals(to_expand2.sizes()) &&
245
+ to_expand1.sizes().equals(to_expand3.sizes())) {
246
+ return std::make_tuple(
247
+ c10::MaybeOwned<Tensor>::borrowed(to_expand1),
248
+ c10::MaybeOwned<Tensor>::borrowed(to_expand2),
249
+ c10::MaybeOwned<Tensor>::borrowed(to_expand3));
250
+ }
251
+
252
+ auto expanded_size12 =
253
+ infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
254
+ auto expanded_size =
255
+ infer_size_dimvector(expanded_size12, to_expand3.sizes());
256
+ return std::make_tuple(
257
+ c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
258
+ c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)),
259
+ c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size)));
260
+ }
261
+
262
+ inline std::tuple<
263
+ c10::MaybeOwned<Tensor>,
264
+ c10::MaybeOwned<Tensor>,
265
+ c10::MaybeOwned<Tensor>>
266
+ expand_outplace(
267
+ Tensor&& to_expand1,
268
+ const Tensor& to_expand2,
269
+ const Tensor& to_expand3) = delete;
270
+ inline std::tuple<
271
+ c10::MaybeOwned<Tensor>,
272
+ c10::MaybeOwned<Tensor>,
273
+ c10::MaybeOwned<Tensor>>
274
+ expand_outplace(
275
+ const Tensor& to_expand1,
276
+ Tensor&& to_expand2,
277
+ const Tensor& to_expand3) = delete;
278
+ inline std::tuple<
279
+ c10::MaybeOwned<Tensor>,
280
+ c10::MaybeOwned<Tensor>,
281
+ c10::MaybeOwned<Tensor>>
282
+ expand_outplace(
283
+ Tensor&& to_expand1,
284
+ Tensor&& to_expand2,
285
+ const Tensor& to_expand3) = delete;
286
+ inline std::tuple<
287
+ c10::MaybeOwned<Tensor>,
288
+ c10::MaybeOwned<Tensor>,
289
+ c10::MaybeOwned<Tensor>>
290
+ expand_outplace(
291
+ const Tensor& to_expand1,
292
+ const Tensor& to_expand2,
293
+ Tensor&& to_expand3) = delete;
294
+ inline std::tuple<
295
+ c10::MaybeOwned<Tensor>,
296
+ c10::MaybeOwned<Tensor>,
297
+ c10::MaybeOwned<Tensor>>
298
+ expand_outplace(
299
+ Tensor&& to_expand1,
300
+ const Tensor& to_expand2,
301
+ Tensor&& to_expand3) = delete;
302
+ inline std::tuple<
303
+ c10::MaybeOwned<Tensor>,
304
+ c10::MaybeOwned<Tensor>,
305
+ c10::MaybeOwned<Tensor>>
306
+ expand_outplace(
307
+ const Tensor& to_expand1,
308
+ Tensor&& to_expand2,
309
+ Tensor&& to_expand3) = delete;
310
+ inline std::tuple<
311
+ c10::MaybeOwned<Tensor>,
312
+ c10::MaybeOwned<Tensor>,
313
+ c10::MaybeOwned<Tensor>>
314
+ expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
315
+ delete;
316
+
317
+ inline std::tuple<
318
+ c10::MaybeOwned<Tensor>,
319
+ c10::MaybeOwned<Tensor>,
320
+ c10::MaybeOwned<Tensor>>
321
+ expand_outplace(
322
+ const Tensor& to_expand1,
323
+ const Tensor& to_expand2,
324
+ const Tensor& to_expand3,
325
+ const char* api_name) {
326
+ check_defined({to_expand1, to_expand2, to_expand3}, api_name);
327
+ return expand_outplace(to_expand1, to_expand2, to_expand3);
328
+ }
329
+
330
+ inline std::tuple<
331
+ c10::MaybeOwned<Tensor>,
332
+ c10::MaybeOwned<Tensor>,
333
+ c10::MaybeOwned<Tensor>>
334
+ expand_outplace(
335
+ Tensor&& to_expand1,
336
+ const Tensor& to_expand2,
337
+ const Tensor& to_expand3,
338
+ const char* api_name) = delete;
339
+ inline std::tuple<
340
+ c10::MaybeOwned<Tensor>,
341
+ c10::MaybeOwned<Tensor>,
342
+ c10::MaybeOwned<Tensor>>
343
+ expand_outplace(
344
+ const Tensor& to_expand1,
345
+ Tensor&& to_expand2,
346
+ const Tensor& to_expand3,
347
+ const char* api_name) = delete;
348
+ inline std::tuple<
349
+ c10::MaybeOwned<Tensor>,
350
+ c10::MaybeOwned<Tensor>,
351
+ c10::MaybeOwned<Tensor>>
352
+ expand_outplace(
353
+ Tensor&& to_expand1,
354
+ Tensor&& to_expand2,
355
+ const Tensor& to_expand3,
356
+ const char* api_name) = delete;
357
+ inline std::tuple<
358
+ c10::MaybeOwned<Tensor>,
359
+ c10::MaybeOwned<Tensor>,
360
+ c10::MaybeOwned<Tensor>>
361
+ expand_outplace(
362
+ const Tensor& to_expand1,
363
+ const Tensor& to_expand2,
364
+ Tensor&& to_expand3,
365
+ const char* api_name) = delete;
366
+ inline std::tuple<
367
+ c10::MaybeOwned<Tensor>,
368
+ c10::MaybeOwned<Tensor>,
369
+ c10::MaybeOwned<Tensor>>
370
+ expand_outplace(
371
+ Tensor&& to_expand1,
372
+ const Tensor& to_expand2,
373
+ Tensor&& to_expand3,
374
+ const char* api_name) = delete;
375
+ inline std::tuple<
376
+ c10::MaybeOwned<Tensor>,
377
+ c10::MaybeOwned<Tensor>,
378
+ c10::MaybeOwned<Tensor>>
379
+ expand_outplace(
380
+ const Tensor& to_expand1,
381
+ Tensor&& to_expand2,
382
+ Tensor&& to_expand3,
383
+ const char* api_name) = delete;
384
+ inline std::tuple<
385
+ c10::MaybeOwned<Tensor>,
386
+ c10::MaybeOwned<Tensor>,
387
+ c10::MaybeOwned<Tensor>>
388
+ expand_outplace(
389
+ Tensor&& to_expand1,
390
+ Tensor&& to_expand2,
391
+ Tensor&& to_expand3,
392
+ const char* api_name) = delete;
393
+
394
+ inline c10::MaybeOwned<Tensor> expand_size(
395
+ const Tensor& to_expand,
396
+ IntArrayRef sizes) {
397
+ if (to_expand.sizes().equals(sizes)) {
398
+ return c10::MaybeOwned<Tensor>::borrowed(to_expand);
399
+ }
400
+
401
+ return c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes));
402
+ }
403
+
404
+ inline c10::MaybeOwned<Tensor> expand_size(
405
+ Tensor&& to_expand,
406
+ IntArrayRef sizes) = delete;
407
+
408
+ inline c10::MaybeOwned<Tensor> expand_size(
409
+ const Tensor& to_expand,
410
+ IntArrayRef sizes,
411
+ const char* api_name) {
412
+ check_defined({to_expand}, api_name);
413
+ return expand_size(to_expand, sizes);
414
+ }
415
+
416
+ inline c10::MaybeOwned<Tensor> expand_size(
417
+ Tensor&& to_expand,
418
+ IntArrayRef sizes,
419
+ const char* api_name) = delete;
420
+
421
+ inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
422
+ // expands a list of Tensors; ignores undefined (null) tensors
423
+ bool first = true;
424
+ SymDimVector sizes;
425
+ for (const auto i : c10::irange(to_expand.size())) {
426
+ if (!to_expand[i].defined()) {
427
+ continue;
428
+ } else if (first) {
429
+ sizes = to_expand[i].sym_sizes();
430
+ first = false;
431
+ } else {
432
+ sizes = infer_size_symdimvector(sizes, to_expand[i].sym_sizes());
433
+ }
434
+ }
435
+
436
+ std::vector<Tensor> result(to_expand.size());
437
+ for (const auto i : c10::irange(to_expand.size())) {
438
+ if (!to_expand[i].defined()) {
439
+ continue;
440
+ } else if (to_expand[i].sym_sizes().equals(sizes)) {
441
+ result[i] = to_expand[i];
442
+ } else {
443
+ result[i] = to_expand[i].expand_symint(sizes);
444
+ }
445
+ }
446
+ return result;
447
+ }
448
+
449
+ template <typename T>
450
+ inline Tensor _sum_to(
451
+ Tensor tensor,
452
+ const c10::ArrayRef<T> shape,
453
+ bool always_return_non_view = false) {
454
+ if (shape.size() == 0) {
455
+ return tensor.sum();
456
+ }
457
+
458
+ auto sizes = at::symint::sizes<T>(tensor);
459
+ c10::SmallVector<int64_t, 8> reduce_dims;
460
+ const int64_t leading_dims = sizes.size() - shape.size();
461
+ for (const auto i : c10::irange(leading_dims)) {
462
+ reduce_dims.push_back(i);
463
+ }
464
+ for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
465
+ if (TORCH_GUARD_OR_FALSE(sym_eq(shape[i - leading_dims], 1)) &&
466
+ TORCH_GUARD_OR_TRUE(sym_ne(sizes[i], 1))) {
467
+ reduce_dims.push_back(i);
468
+ } else {
469
+ // if we assume no reduction due to unbacked we ensure that at runtime.
470
+ TORCH_MAYBE_SYM_CHECK(
471
+ sym_eq(shape[i - leading_dims], sizes[i]),
472
+ "non-reduction path was assumed due to unbacked symbols expected those two sizes to be the same:",
473
+ shape[i - leading_dims],
474
+ ", ",
475
+ sizes[i])
476
+ }
477
+ }
478
+
479
+ if (!reduce_dims.empty()) {
480
+ tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
481
+ }
482
+
483
+ if (always_return_non_view) {
484
+ // This is only actually used by the functionalization pass.
485
+ // We want to be able to guarantee that this function doesn't return a view
486
+ // of the input.
487
+ return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape)
488
+ : tensor.clone();
489
+ } else {
490
+ return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor;
491
+ }
492
+ }
493
+
494
+ inline Tensor sum_to(
495
+ Tensor tensor,
496
+ const c10::SymIntArrayRef shape,
497
+ bool always_return_non_view = false) {
498
+ return _sum_to(std::move(tensor), shape, always_return_non_view);
499
+ }
500
+
501
+ // Sums `tensor` repeatedly to produce a tensor of shape `shape`.
502
+ // Precondition: is_expandable_to(shape, tensor.sizes()) must be true
503
+ inline Tensor sum_to(
504
+ Tensor tensor,
505
+ const IntArrayRef shape,
506
+ bool always_return_non_view = false) {
507
+ return _sum_to(std::move(tensor), shape, always_return_non_view);
508
+ }
509
+
510
+ inline bool is_expandable_to(
511
+ SymIntArrayRef shape,
512
+ c10::SymIntArrayRef desired) {
513
+ size_t ndim = shape.size();
514
+ size_t target_dim = desired.size();
515
+ if (ndim > target_dim) {
516
+ return false;
517
+ }
518
+ for (const auto i : c10::irange(ndim)) {
519
+ const auto& size = shape[ndim - i - 1];
520
+ const auto& target = desired[target_dim - i - 1];
521
+ if (size != target && size != 1) {
522
+ return false;
523
+ }
524
+ }
525
+ return true;
526
+ }
527
+
528
+ inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
529
+ auto sym_shape = c10::SymIntArrayRef(
530
+ reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
531
+ auto sym_desired = c10::SymIntArrayRef(
532
+ reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
533
+ return is_expandable_to(sym_shape, sym_desired);
534
+ }
535
+
536
+ } // namespace at
537
+
538
+ #else
539
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
540
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/FunctionalTensorWrapper.h ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/ArrayRef.h>
6
+ #include <ATen/FunctionalStorageImpl.h>
7
+ #include <ATen/core/IListRef.h>
8
+ #include <ATen/core/List.h>
9
+ #include <ATen/core/boxing/BoxedKernel.h>
10
+ #include <ATen/core/boxing/impl/boxing.h>
11
+ #include <ATen/core/dispatch/Dispatcher.h>
12
+
13
+ #include <c10/core/DispatchKey.h>
14
+
15
+ namespace at {
16
+
17
+ // Note [Functionalization Pass In Core]
18
+ // The Functionalization pass is used to remove aliasing from a pytorch program.
19
+ //
20
+ // This is useful for backends that don't support aliasing, like XLA and Vulkan.
21
+ // It's also necessary in order to remove mutation from a program, which is
22
+ // needed in Functorch.
23
+ //
24
+ // Consider this program:
25
+ // a = torch.ones(...)
26
+ // b = a.view(...)
27
+ // b.add_(1)
28
+ //
29
+ // In this program, b is meant to alias with a due to the use of view(). At the
30
+ // end of the program, both a and b are full of 2's. However, backends that
31
+ // don't support aliasing aren't able to correctly implement the view()
32
+ // operator. Instead, they can opt into the Functionalization pass, which will
33
+ // sit between the user and the backend, and provide the necessary aliasing
34
+ // logic.
35
+ //
36
+ // The functionalization pass will turn the above program into a slightly
37
+ // different program that has the same semantics, transparently to the user,
38
+ // that backends like XLA/Vulkan are able to implement a = torch.ones(...) b =
39
+ // a.view_copy(...) # view() replaced with view_copy(). Backends like
40
+ // XLA/Vulkan can implement this! b.add_(1) a.add_(1) # Our functionalization
41
+ // pass machinery knows that a and b are aliased - it applies b's mutation to a
42
+ // too.
43
+ //
44
+ // So, how does the functionalization pass keep track of which tensors are
45
+ // aliased? The pass works by wrapping EVERY tensor in the program inside of a
46
+ // FunctionalTensorWrapper, which knows about its alias'd tensors.
47
+ //
48
+ // See Note [Functionalization: Alias Removal] for details on the aliasing
49
+ // machinery. See Note [Functionalization: Mutation Removal] for details on
50
+ // mutation removal.
51
+ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
52
+ explicit FunctionalTensorWrapper(const Tensor& value);
53
+ // Additional constructor to create a FunctionalTensorWrapper directly from an
54
+ // underlying tensor that was created from a view. For example, the code b =
55
+ // a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
56
+ // view1_meta)
57
+ explicit FunctionalTensorWrapper(
58
+ const Tensor& view_value,
59
+ const FunctionalTensorWrapper* base,
60
+ const std::shared_ptr<functionalization::ViewMeta>& meta);
61
+
62
+ // Get the underlying, actual tensor, that doesn't know anything about
63
+ // functionalization.
64
+ const Tensor& value() const {
65
+ return value_;
66
+ }
67
+ // The concept of "level" is only ever important to functorch; it's exposed
68
+ // here as more of a hook for functorch to use.
69
+ int64_t level() const {
70
+ return level_;
71
+ }
72
+ void set_level(int64_t level) {
73
+ level_ = level;
74
+ }
75
+ bool has_metadata_mutation() const {
76
+ return has_metadata_mutation_;
77
+ }
78
+ uint64_t mutation_counter() const {
79
+ return functional_storage_impl()->mutation_counter();
80
+ }
81
+ void mark_mutation() {
82
+ functional_storage_impl()->mark_mutation();
83
+ }
84
+ // Denotes a mutation that's hidden from autograd,
85
+ // e.g. for the purposes of passing a tensor to a triton kernel
86
+ void mark_mutation_hidden_from_autograd() {
87
+ functional_storage_impl()->mark_mutation_hidden_from_autograd();
88
+ }
89
+ void mark_mutation_during_no_grad_or_inference_mode() {
90
+ functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
91
+ }
92
+ // Are all the mutations happening to the tensor hidden from autograd
93
+ bool are_all_mutations_hidden_from_autograd() const {
94
+ return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
95
+ }
96
+ // Did all mutations happen under no_grad or inference_mode
97
+ // (We also need to ignore mutations fully hidden from autograd here)
98
+ bool are_all_mutations_under_no_grad_or_inference_mode() const {
99
+ return functional_storage_impl()
100
+ ->are_all_mutations_under_no_grad_or_inference_mode();
101
+ }
102
+
103
+ void maybe_mark_symbolic(functionalization::ViewMeta* meta) {
104
+ is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs;
105
+ }
106
+
107
+ bool is_symbolic() const {
108
+ return is_symbolic_;
109
+ }
110
+
111
+ // Retrieves the ViewMeta sequence of this tensor.
112
+ const std::vector<std::shared_ptr<functionalization::ViewMeta>>& view_metas()
113
+ const;
114
+
115
+ // Sync's the underlying tensor with its alias, if it's out of date. This
116
+ // involves two steps: 1) Apply any pending updates/mutations to the alias 2)
117
+ // Replay the views (if any) to regenerate the current tensor off of the
118
+ // updated alias.
119
+ void sync_();
120
+ // Performs step (1) of the sync. This is its own public API because it's
121
+ // needed by view_inplace ops like transpose_. See Note [Functionalization
122
+ // Pass - Inplace View Ops]
123
+ void regenerate_from_base();
124
+ // Performs step (2) of the sync. This is its own public API because it's
125
+ // needed by functorch. functorch wants to make sure that all input tensors to
126
+ // a functionalized program have been properly synced so it can properly
127
+ // propagate mutations to inputs. It can't just call sync_(), because the
128
+ // FunctionalTensorWrapper will look like it has no aliases and sync_ will be
129
+ // a noop. We use the reference count on storage_ to determine if the wrapper
130
+ // is aliased, and by the time functorch is ready to propagate updates to
131
+ // inputs, any intermediate views of the input created by the program will
132
+ // have been deallocated. This function also returns whether or not the base
133
+ // actually had any updates to apply.
134
+ bool apply_updates();
135
+ // Takes the current state of value_ and snapshots it, sending it as a pending
136
+ // update to the alias.
137
+ void commit_update();
138
+ // When any tensor is mutated, the tensor increments its alias's "generation".
139
+ // Separately, each tensor maintains its own "generation" counter, which is
140
+ // used to determine if it's up-to-date with its alias. The act of syncing a
141
+ // tensor will set a tensor's generation equal to its alias's generation.
142
+ bool is_up_to_date() const;
143
+ // Freezes the storage of this tensor, preventing subsequent mutations
144
+ void freeze_storage() const;
145
+ // Every FunctionalTensorWrapper contains a vector<ViewMeta> objects
146
+ // describing the series of view ops that ran to generate the current tensor
147
+ // from the base tensor. This method is used by inplace-view ops like
148
+ // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
149
+ // tensor by replaying the views off of the alias.
150
+ void mutate_view_meta(
151
+ const std::shared_ptr<at::functionalization::ViewMeta>& meta);
152
+
153
+ // Custom implementation of self.set_(src)
154
+ void set__impl(const FunctionalTensorWrapper* other);
155
+
156
+ // Custom implementation of resize_storage_bytes_(self, new_size)
157
+ void storage_resize_(const c10::SymInt& new_size);
158
+
159
+ // Returns whether the current tensor's data was ever mutated
160
+ bool has_data_mutation();
161
+ //
162
+ // Returns whether the current FunctionalTensorWrapper
163
+ // experienced a set_() call.
164
+ bool was_storage_changed() {
165
+ return was_storage_changed_;
166
+ }
167
+
168
+ void mark_storage_changed() {
169
+ was_storage_changed_ = true;
170
+ storage_changed_counter_++;
171
+ }
172
+
173
+ uint64_t storage_changed_counter() {
174
+ return storage_changed_counter_;
175
+ }
176
+
177
+ // A FunctionalTensor is considered a base if its not a view of another
178
+ // tensor.
179
+ bool isBaseTensor() const {
180
+ return view_metas_.empty();
181
+ }
182
+
183
+ c10::SymInt get_storage_size(bool before) {
184
+ return functional_storage_impl()->get_storage_size(before);
185
+ }
186
+
187
+ // Returns whether the FunctionalTensor experienced an
188
+ // untyped_storage().resize_() call
189
+ bool was_inductor_storage_resized() {
190
+ return functional_storage_impl()->was_inductor_storage_resized();
191
+ }
192
+
193
+ bool inductor_storage_resized_counter() {
194
+ return functional_storage_impl()->inductor_storage_resized_counter();
195
+ }
196
+ // The functionalization pass can be used to remove mutations.
197
+ // It does so by replacing any mutation op with it's corresponding
198
+ // out-of-place op, followed by a call to replace_(). e.g:
199
+ //
200
+ // a.add_(1)
201
+ //
202
+ // will turn into:
203
+ //
204
+ // tmp = a.add(1)
205
+ // a.replace_(tmp)
206
+ //
207
+ // replace_() swaps out the wrapped tensor, value_, with tmp.
208
+ void replace_(const Tensor& other, bool from_lazy_regenerate = false);
209
+
210
+ bool is_multi_output_view() {
211
+ return is_multi_output_view_;
212
+ }
213
+
214
+ // See Note[resize_() in functionalization pass]
215
+ void maybe_replace_storage(const Tensor& other);
216
+
217
+ // Replaces the storage with a new functional storage,
218
+ // and clears the view_metas_ stack.
219
+ // WARNING: Calling this function will sever the aliasing relationship between
220
+ // the current FunctionalTensorWrapper and any of its outstanding aliases.
221
+ // Please only call if you know what you're doing.
222
+ void _unsafe_reset_storage();
223
+
224
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
225
+ const c10::VariableVersion& version_counter,
226
+ bool allow_tensor_metadata_change) const override;
227
+
228
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
229
+ c10::VariableVersion&& version_counter,
230
+ bool allow_tensor_metadata_change) const override;
231
+
232
+ ~FunctionalTensorWrapper() override = default;
233
+
234
+ // FunctionalTensorWrapper overrides all custom size/stride function,
235
+ // so that if the inner tensor has a custom implementation
236
+ // we make sure to call that implementation.
237
+ at::IntArrayRef sizes_custom() const override;
238
+ at::IntArrayRef strides_custom() const override;
239
+ int64_t dim_custom() const override;
240
+ int64_t numel_custom() const override;
241
+ c10::SymBool sym_is_contiguous_custom(
242
+ at::MemoryFormat memory_format) const override;
243
+ c10::SymIntArrayRef sym_sizes_custom() const override;
244
+ c10::SymInt sym_size_custom(int64_t d) const override;
245
+ c10::SymIntArrayRef sym_strides_custom() const override;
246
+ c10::SymInt sym_storage_offset_custom() const override;
247
+ c10::Device device_custom() const override;
248
+ c10::Layout layout_impl() const override;
249
+
250
+ private:
251
+ const char* tensorimpl_type_name() const override;
252
+ void set_constructor_metadata();
253
+ functionalization::FunctionalStorageImpl* functional_storage_impl() const;
254
+
255
+ // This is used to re-implement shallow_copy_and_detach for
256
+ // FunctionalTensorWrapper. The implementation is identical, but we just need
257
+ // to return a subclass instead of a plain TensorImpl.
258
+ // TODO: maybe it's possible to arrange for that to happen automatically
259
+ // without an override here?
260
+ template <typename VariableVersion>
261
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
262
+ VariableVersion&& version_counter,
263
+ bool allow_tensor_metadata_change) const;
264
+
265
+ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
266
+ void copy_tensor_metadata_and_refresh(
267
+ const FunctionalTensorWrapper* src_impl,
268
+ FunctionalTensorWrapper* dest_impl,
269
+ const c10::VariableVersion& version_counter,
270
+ bool allow_tensor_metadata_change) const;
271
+
272
+ // Note that value is not taken by reference: internally, the wrapper will
273
+ // change the value tensor that it points to over time.
274
+ Tensor value_;
275
+ int64_t level_{};
276
+ // These two counters are used for identifying
277
+ // whether all the mutations on a given tensor are hidden from autograd or
278
+ // not. If we have an input mutation that is hidden from autograd, then once
279
+ // we convert the input mutation to a copy_() we know it will be safe to hide
280
+ // the copy_() from autograd as well.
281
+ bool has_metadata_mutation_ = false;
282
+ bool is_multi_output_view_ = false;
283
+ // Did the tensor experience a set_() call.
284
+ bool was_storage_changed_ = false;
285
+ uint64_t storage_changed_counter_ = 0;
286
+ // Did the tensor experience any view operation with symbolic int.
287
+ bool is_symbolic_ = false;
288
+
289
+ size_t generation_ = 0;
290
+ std::vector<std::shared_ptr<at::functionalization::ViewMeta>> view_metas_;
291
+
292
+ protected:
293
+ static void copy_tensor_metadata(
294
+ const FunctionalTensorWrapper* src_impl,
295
+ FunctionalTensorWrapper* dest_impl,
296
+ const c10::VariableVersion& version_counter,
297
+ bool allow_tensor_metadata_change);
298
+ };
299
+
300
+ // Utility functions for the functionalization pass.
301
+
302
+ namespace functionalization {
303
+ namespace impl {
304
+
305
+ inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
306
+ const Tensor& tensor) {
307
+ auto functional_impl =
308
+ static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
309
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
310
+ return functional_impl;
311
+ }
312
+
313
+ TORCH_API bool isBaseTensor(const at::Tensor& tensor);
314
+
315
+ TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
316
+ TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
317
+ TORCH_API bool isFunctionalTensor(
318
+ const c10::List<std::optional<Tensor>>& t_list);
319
+ TORCH_API bool isFunctionalTensor(ITensorListRef list);
320
+
321
+ TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
322
+ TORCH_API std::optional<Tensor> to_functional_tensor(
323
+ const std::optional<Tensor>& tensor);
324
+ TORCH_API c10::List<std::optional<Tensor>> to_functional_tensor(
325
+ const c10::List<std::optional<Tensor>>& t_list);
326
+ TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
327
+
328
+ TORCH_API void freeze_functional_tensor(const Tensor& tensor);
329
+
330
+ TORCH_API Tensor
331
+ from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
332
+ TORCH_API std::optional<Tensor> from_functional_tensor(
333
+ const std::optional<Tensor>& t,
334
+ bool assert_functional = true);
335
+ TORCH_API c10::List<std::optional<Tensor>> from_functional_tensor(
336
+ const c10::List<std::optional<Tensor>>& t_list);
337
+ TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
338
+
339
+ TORCH_API void sync(const at::Tensor& t);
340
+ TORCH_API void sync(const std::optional<Tensor>& t);
341
+ TORCH_API void sync(const c10::List<std::optional<Tensor>>& t_list);
342
+ TORCH_API void sync(ITensorListRef t_list);
343
+
344
+ TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
345
+ TORCH_API void replace_(
346
+ const ITensorListRef functional_tensor,
347
+ ITensorListRef other);
348
+
349
+ TORCH_API void commit_update(const Tensor& functional_tensor);
350
+ TORCH_API void commit_update(ITensorListRef functional_tensor);
351
+
352
+ TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
353
+
354
+ TORCH_API void mark_mutation_hidden_from_autograd(
355
+ const Tensor& functional_tensor);
356
+
357
+ TORCH_API bool are_all_mutations_hidden_from_autograd(
358
+ const Tensor& functional_tensor);
359
+
360
+ TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
361
+ const Tensor& functional_tensor);
362
+
363
+ // These two methods are XLA-specific logic and are no-ops
364
+ // for the normal functionalization flow.
365
+ TORCH_API void propagate_xla_data(
366
+ const Tensor& functional_tensor,
367
+ const Tensor& other);
368
+ TORCH_API void propagate_xla_data(
369
+ const ITensorListRef functional_tensor,
370
+ ITensorListRef other);
371
+
372
+ TORCH_API void propagate_xla_data_direct(
373
+ const Tensor& tensor,
374
+ const Tensor& other);
375
+ TORCH_API void propagate_xla_data_direct(
376
+ const ITensorListRef tensor,
377
+ ITensorListRef other);
378
+
379
+ Tensor create_functional_tensor_with_view_meta(
380
+ const Tensor& view_to_wrap,
381
+ const Tensor& base,
382
+ const std::shared_ptr<functionalization::ViewMeta>& meta,
383
+ int64_t out_idx = 0);
384
+ std::vector<Tensor> create_functional_tensor_with_view_meta(
385
+ ITensorListRef view_to_wrap,
386
+ const Tensor& base,
387
+ const std::shared_ptr<functionalization::ViewMeta>& meta);
388
+
389
+ void mutate_view_meta(
390
+ const Tensor& self,
391
+ const std::shared_ptr<functionalization::ViewMeta>& meta);
392
+
393
+ TORCH_API Tensor apply_view_meta_sequence(
394
+ const Tensor& base,
395
+ const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence);
396
+
397
+ void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
398
+ void set_sizes_strides_offset(
399
+ const std::vector<Tensor>& outs,
400
+ const std::vector<Tensor>& meta_outs);
401
+
402
+ // ~~~~~ TLS used in functionalization ~~~~~
403
+
404
+ TORCH_API bool getFunctionalizationReapplyViewsTLS();
405
+ TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
406
+
407
+ class TORCH_API FunctionalizationReapplyViewsGuard {
408
+ public:
409
+ FunctionalizationReapplyViewsGuard(bool reapply_views)
410
+ : prev_(getFunctionalizationReapplyViewsTLS()) {
411
+ setFunctionalizationReapplyViewsTLS(reapply_views);
412
+ }
413
+
414
+ ~FunctionalizationReapplyViewsGuard() {
415
+ setFunctionalizationReapplyViewsTLS(prev_);
416
+ }
417
+
418
+ FunctionalizationReapplyViewsGuard(
419
+ const FunctionalizationReapplyViewsGuard&) = delete;
420
+ FunctionalizationReapplyViewsGuard operator=(
421
+ const FunctionalizationReapplyViewsGuard&) = delete;
422
+ FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
423
+ delete;
424
+ FunctionalizationReapplyViewsGuard operator=(
425
+ FunctionalizationReapplyViewsGuard&&) = delete;
426
+
427
+ private:
428
+ bool prev_;
429
+ };
430
+
431
+ } // namespace impl
432
+
433
+ // Helper function to call an out-of-place composite aten kernel that may use
434
+ // mutations / views internally, and functionalize them.
435
+ TORCH_API void functionalize_op_helper(
436
+ const c10::OperatorHandle& op,
437
+ torch::jit::Stack* stack);
438
+
439
+ template <class Op, bool symint, class ReturnType, class... ParameterTypes>
440
+ struct _functionalize_aten_op final {};
441
+
442
+ template <class Op, bool symint, class ReturnType, class... ParameterTypes>
443
+ struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
444
+ static ReturnType call(
445
+ typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
446
+ using FuncType = ReturnType(
447
+ typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
448
+ auto op = c10::Dispatcher::singleton()
449
+ .findSchemaOrThrow(
450
+ (const char*)Op::name, (const char*)Op::overload_name)
451
+ .typed<FuncType>();
452
+
453
+ return c10::impl::BoxedKernelWrapper<FuncType>::call(
454
+ c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
455
+ op,
456
+ // BoxedKernelWrapper knows to ignore this keyset argument,
457
+ // because functionalize_op_helper doesn't take in a DispatchKeySet
458
+ c10::DispatchKeySet(),
459
+ args...);
460
+ }
461
+ };
462
+
463
+ template <class Op>
464
+ using functionalize_aten_op =
465
+ _functionalize_aten_op<Op, false, typename Op::schema>;
466
+
467
+ template <class Op>
468
+ using functionalize_aten_op_symint =
469
+ _functionalize_aten_op<Op, true, typename Op::schema>;
470
+
471
+ } // namespace functionalization
472
+ } // namespace at
473
+
474
+ #else
475
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
476
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Functions.h ADDED
@@ -0,0 +1,1476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // @generated by torchgen/gen.py from Functions.h
5
+
6
+ #ifdef TORCH_ASSERT_NO_OPERATORS
7
+ #error This change adds a dependency on native_functions.yaml, \
8
+ meaning the file will need to be re-compiled every time an operator \
9
+ is changed or added. Consider if your change would be better placed in \
10
+ another file, or if a more specific header might achieve the same goal. \
11
+ See NOTE: [Tensor vs. TensorBase]
12
+ #endif
13
+
14
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
15
+ #error This change adds a dependency on all pytorch operators, meaning the \
16
+ file will need to be re-compiled every time an operator is changed or added. \
17
+ Consider including a specific operator from <ATen/ops/{my_operator}.h> and \
18
+ see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
19
+ #endif
20
+
21
+ // NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS]
22
+ //
23
+ // In ATen, certain generated headers files include the definitions of
24
+ // every single operator in PyTorch. Unfortunately this means every
25
+ // time an operator signature is updated or changed in
26
+ // native_functions.yaml, you (and every other PyTorch developer) need
27
+ // to recompile every source file that includes any of these headers.
28
+ //
29
+ // To break up these header dependencies, and improve incremental
30
+ // build times for all PyTorch developers. These headers are split
31
+ // into per-operator headers in the `ATen/ops` folder. This limits
32
+ // incremental builds to only changes to methods of `Tensor`, or files
33
+ // that use the specific operator being changed. With `at::sum` as an
34
+ // example, you should include
35
+ //
36
+ // <ATen/ops/sum.h> // instead of ATen/Functions.h
37
+ // <ATen/ops/sum_native.h> // instead of ATen/NativeFunctions.h
38
+ // <ATen/ops/sum_ops.h> // instead of ATen/Operators.h
39
+ // <ATen/ops/sum_cpu_dispatch.h> // instead of ATen/CPUFunctions.h
40
+ //
41
+ // However, even if you're careful to use this in your own code.
42
+ // `Functions.h` might be included indirectly through another header
43
+ // without you realising. To avoid this, you can add
44
+ //
45
+ // #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
46
+ //
47
+ // to the top of your source file. This way any time the non-specific
48
+ // headers are included, the compiler will error out.
49
+ //
50
+ // Also, be aware that `ops` are not available in all build
51
+ // configurations (namely fb-internal) so you must guard these
52
+ // includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g.
53
+ //
54
+ // #ifndef AT_PER_OPERATOR_HEADERS
55
+ // #include <ATen/Functions.h>
56
+ // #else
57
+ // #include <ATen/ops/sum.h>
58
+ // #endif
59
+
60
+ #include <ATen/Context.h>
61
+ #include <ATen/DeviceGuard.h>
62
+ #include <ATen/TensorUtils.h>
63
+ #include <ATen/TracerMode.h>
64
+ #include <ATen/core/Generator.h>
65
+ #include <ATen/core/Reduction.h>
66
+ #include <c10/core/SymInt.h>
67
+ #include <ATen/core/Tensor.h>
68
+ #include <c10/core/Scalar.h>
69
+ #include <c10/core/Storage.h>
70
+ #include <c10/core/TensorOptions.h>
71
+ #include <c10/util/Deprecated.h>
72
+ #include <optional>
73
+ #include <c10/util/OptionalArrayRef.h>
74
+
75
+ #include <ATen/ops/from_blob.h>
76
+ #include <ATen/ops/tensor.h>
77
+
78
+ #include <ATen/ops/_adaptive_avg_pool2d.h>
79
+ #include <ATen/ops/_adaptive_avg_pool2d_backward.h>
80
+ #include <ATen/ops/_adaptive_avg_pool3d.h>
81
+ #include <ATen/ops/_adaptive_avg_pool3d_backward.h>
82
+ #include <ATen/ops/_add_batch_dim.h>
83
+ #include <ATen/ops/_add_relu.h>
84
+ #include <ATen/ops/_addmm_activation.h>
85
+ #include <ATen/ops/_aminmax.h>
86
+ #include <ATen/ops/_amp_foreach_non_finite_check_and_unscale.h>
87
+ #include <ATen/ops/_amp_update_scale.h>
88
+ #include <ATen/ops/_assert_async.h>
89
+ #include <ATen/ops/_assert_scalar.h>
90
+ #include <ATen/ops/_assert_tensor_metadata.h>
91
+ #include <ATen/ops/_autocast_to_full_precision.h>
92
+ #include <ATen/ops/_autocast_to_reduced_precision.h>
93
+ #include <ATen/ops/_backward.h>
94
+ #include <ATen/ops/_batch_norm_impl_index.h>
95
+ #include <ATen/ops/_batch_norm_impl_index_backward.h>
96
+ #include <ATen/ops/_batch_norm_no_update.h>
97
+ #include <ATen/ops/_batch_norm_with_update.h>
98
+ #include <ATen/ops/_cast_Byte.h>
99
+ #include <ATen/ops/_cast_Char.h>
100
+ #include <ATen/ops/_cast_Double.h>
101
+ #include <ATen/ops/_cast_Float.h>
102
+ #include <ATen/ops/_cast_Half.h>
103
+ #include <ATen/ops/_cast_Int.h>
104
+ #include <ATen/ops/_cast_Long.h>
105
+ #include <ATen/ops/_cast_Short.h>
106
+ #include <ATen/ops/_cdist_backward.h>
107
+ #include <ATen/ops/_cdist_forward.h>
108
+ #include <ATen/ops/_cholesky_solve_helper.h>
109
+ #include <ATen/ops/_choose_qparams_per_tensor.h>
110
+ #include <ATen/ops/_chunk_cat.h>
111
+ #include <ATen/ops/_coalesce.h>
112
+ #include <ATen/ops/_coalesced.h>
113
+ #include <ATen/ops/_compute_linear_combination.h>
114
+ #include <ATen/ops/_conj.h>
115
+ #include <ATen/ops/_conj_copy.h>
116
+ #include <ATen/ops/_conj_physical.h>
117
+ #include <ATen/ops/_conv_depthwise2d.h>
118
+ #include <ATen/ops/_convert_indices_from_coo_to_csr.h>
119
+ #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
120
+ #include <ATen/ops/_convert_weight_to_int4pack.h>
121
+ #include <ATen/ops/_convert_weight_to_int4pack_for_cpu.h>
122
+ #include <ATen/ops/_convolution.h>
123
+ #include <ATen/ops/_convolution_double_backward.h>
124
+ #include <ATen/ops/_convolution_mode.h>
125
+ #include <ATen/ops/_copy_from.h>
126
+ #include <ATen/ops/_copy_from_and_resize.h>
127
+ #include <ATen/ops/_cslt_compress.h>
128
+ #include <ATen/ops/_cslt_sparse_mm.h>
129
+ #include <ATen/ops/_cslt_sparse_mm_search.h>
130
+ #include <ATen/ops/_ctc_loss.h>
131
+ #include <ATen/ops/_ctc_loss_backward.h>
132
+ #include <ATen/ops/_cudnn_attention_backward.h>
133
+ #include <ATen/ops/_cudnn_attention_forward.h>
134
+ #include <ATen/ops/_cudnn_ctc_loss.h>
135
+ #include <ATen/ops/_cudnn_init_dropout_state.h>
136
+ #include <ATen/ops/_cudnn_rnn.h>
137
+ #include <ATen/ops/_cudnn_rnn_backward.h>
138
+ #include <ATen/ops/_cudnn_rnn_flatten_weight.h>
139
+ #include <ATen/ops/_cufft_clear_plan_cache.h>
140
+ #include <ATen/ops/_cufft_get_plan_cache_max_size.h>
141
+ #include <ATen/ops/_cufft_get_plan_cache_size.h>
142
+ #include <ATen/ops/_cufft_set_plan_cache_max_size.h>
143
+ #include <ATen/ops/_cummax_helper.h>
144
+ #include <ATen/ops/_cummin_helper.h>
145
+ #include <ATen/ops/_debug_has_internal_overlap.h>
146
+ #include <ATen/ops/_dimI.h>
147
+ #include <ATen/ops/_dimV.h>
148
+ #include <ATen/ops/_dim_arange.h>
149
+ #include <ATen/ops/_dirichlet_grad.h>
150
+ #include <ATen/ops/_dyn_quant_matmul_4bit.h>
151
+ #include <ATen/ops/_dyn_quant_pack_4bit_weight.h>
152
+ #include <ATen/ops/_efficient_attention_backward.h>
153
+ #include <ATen/ops/_efficient_attention_forward.h>
154
+ #include <ATen/ops/_efficientzerotensor.h>
155
+ #include <ATen/ops/_embedding_bag.h>
156
+ #include <ATen/ops/_embedding_bag_backward.h>
157
+ #include <ATen/ops/_embedding_bag_dense_backward.h>
158
+ #include <ATen/ops/_embedding_bag_forward_only.h>
159
+ #include <ATen/ops/_embedding_bag_per_sample_weights_backward.h>
160
+ #include <ATen/ops/_embedding_bag_sparse_backward.h>
161
+ #include <ATen/ops/_empty_affine_quantized.h>
162
+ #include <ATen/ops/_empty_per_channel_affine_quantized.h>
163
+ #include <ATen/ops/_euclidean_dist.h>
164
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine.h>
165
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward.h>
166
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine.h>
167
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward.h>
168
+ #include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.h>
169
+ #include <ATen/ops/_fft_c2c.h>
170
+ #include <ATen/ops/_fft_c2r.h>
171
+ #include <ATen/ops/_fft_r2c.h>
172
+ #include <ATen/ops/_fill_mem_eff_dropout_mask.h>
173
+ #include <ATen/ops/_flash_attention_backward.h>
174
+ #include <ATen/ops/_flash_attention_forward.h>
175
+ #include <ATen/ops/_foobar.h>
176
+ #include <ATen/ops/_foreach_abs.h>
177
+ #include <ATen/ops/_foreach_acos.h>
178
+ #include <ATen/ops/_foreach_add.h>
179
+ #include <ATen/ops/_foreach_addcdiv.h>
180
+ #include <ATen/ops/_foreach_addcmul.h>
181
+ #include <ATen/ops/_foreach_asin.h>
182
+ #include <ATen/ops/_foreach_atan.h>
183
+ #include <ATen/ops/_foreach_ceil.h>
184
+ #include <ATen/ops/_foreach_clamp_max.h>
185
+ #include <ATen/ops/_foreach_clamp_min.h>
186
+ #include <ATen/ops/_foreach_copy.h>
187
+ #include <ATen/ops/_foreach_cos.h>
188
+ #include <ATen/ops/_foreach_cosh.h>
189
+ #include <ATen/ops/_foreach_div.h>
190
+ #include <ATen/ops/_foreach_erf.h>
191
+ #include <ATen/ops/_foreach_erfc.h>
192
+ #include <ATen/ops/_foreach_exp.h>
193
+ #include <ATen/ops/_foreach_expm1.h>
194
+ #include <ATen/ops/_foreach_floor.h>
195
+ #include <ATen/ops/_foreach_frac.h>
196
+ #include <ATen/ops/_foreach_lerp.h>
197
+ #include <ATen/ops/_foreach_lgamma.h>
198
+ #include <ATen/ops/_foreach_log.h>
199
+ #include <ATen/ops/_foreach_log10.h>
200
+ #include <ATen/ops/_foreach_log1p.h>
201
+ #include <ATen/ops/_foreach_log2.h>
202
+ #include <ATen/ops/_foreach_max.h>
203
+ #include <ATen/ops/_foreach_maximum.h>
204
+ #include <ATen/ops/_foreach_minimum.h>
205
+ #include <ATen/ops/_foreach_mul.h>
206
+ #include <ATen/ops/_foreach_neg.h>
207
+ #include <ATen/ops/_foreach_norm.h>
208
+ #include <ATen/ops/_foreach_pow.h>
209
+ #include <ATen/ops/_foreach_reciprocal.h>
210
+ #include <ATen/ops/_foreach_round.h>
211
+ #include <ATen/ops/_foreach_rsqrt.h>
212
+ #include <ATen/ops/_foreach_sigmoid.h>
213
+ #include <ATen/ops/_foreach_sign.h>
214
+ #include <ATen/ops/_foreach_sin.h>
215
+ #include <ATen/ops/_foreach_sinh.h>
216
+ #include <ATen/ops/_foreach_sqrt.h>
217
+ #include <ATen/ops/_foreach_sub.h>
218
+ #include <ATen/ops/_foreach_tan.h>
219
+ #include <ATen/ops/_foreach_tanh.h>
220
+ #include <ATen/ops/_foreach_trunc.h>
221
+ #include <ATen/ops/_foreach_zero.h>
222
+ #include <ATen/ops/_functional_assert_async.h>
223
+ #include <ATen/ops/_functional_assert_scalar.h>
224
+ #include <ATen/ops/_functional_sym_constrain_range.h>
225
+ #include <ATen/ops/_functional_sym_constrain_range_for_size.h>
226
+ #include <ATen/ops/_fused_adagrad.h>
227
+ #include <ATen/ops/_fused_adam.h>
228
+ #include <ATen/ops/_fused_adamw.h>
229
+ #include <ATen/ops/_fused_dropout.h>
230
+ #include <ATen/ops/_fused_moving_avg_obs_fq_helper.h>
231
+ #include <ATen/ops/_fused_rms_norm.h>
232
+ #include <ATen/ops/_fused_rms_norm_backward.h>
233
+ #include <ATen/ops/_fused_sdp_choice.h>
234
+ #include <ATen/ops/_fused_sgd.h>
235
+ #include <ATen/ops/_fw_primal.h>
236
+ #include <ATen/ops/_fw_primal_copy.h>
237
+ #include <ATen/ops/_gather_sparse_backward.h>
238
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback.h>
239
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward.h>
240
+ #include <ATen/ops/_grouped_mm.h>
241
+ #include <ATen/ops/_has_compatible_shallow_copy_type.h>
242
+ #include <ATen/ops/_has_same_storage_numel.h>
243
+ #include <ATen/ops/_histogramdd_bin_edges.h>
244
+ #include <ATen/ops/_histogramdd_from_bin_cts.h>
245
+ #include <ATen/ops/_histogramdd_from_bin_tensors.h>
246
+ #include <ATen/ops/_index_put_impl.h>
247
+ #include <ATen/ops/_indices.h>
248
+ #include <ATen/ops/_indices_copy.h>
249
+ #include <ATen/ops/_int_mm.h>
250
+ #include <ATen/ops/_is_all_true.h>
251
+ #include <ATen/ops/_is_any_true.h>
252
+ #include <ATen/ops/_is_zerotensor.h>
253
+ #include <ATen/ops/_jagged_to_padded_dense_forward.h>
254
+ #include <ATen/ops/_lazy_clone.h>
255
+ #include <ATen/ops/_linalg_check_errors.h>
256
+ #include <ATen/ops/_linalg_det.h>
257
+ #include <ATen/ops/_linalg_eigh.h>
258
+ #include <ATen/ops/_linalg_eigvals.h>
259
+ #include <ATen/ops/_linalg_slogdet.h>
260
+ #include <ATen/ops/_linalg_solve_ex.h>
261
+ #include <ATen/ops/_linalg_svd.h>
262
+ #include <ATen/ops/_local_scalar_dense.h>
263
+ #include <ATen/ops/_log_softmax.h>
264
+ #include <ATen/ops/_log_softmax_backward_data.h>
265
+ #include <ATen/ops/_logcumsumexp.h>
266
+ #include <ATen/ops/_lstm_mps.h>
267
+ #include <ATen/ops/_lu_with_info.h>
268
+ #include <ATen/ops/_make_dep_token.h>
269
+ #include <ATen/ops/_make_dual.h>
270
+ #include <ATen/ops/_make_dual_copy.h>
271
+ #include <ATen/ops/_make_per_channel_quantized_tensor.h>
272
+ #include <ATen/ops/_make_per_tensor_quantized_tensor.h>
273
+ #include <ATen/ops/_masked_scale.h>
274
+ #include <ATen/ops/_masked_softmax.h>
275
+ #include <ATen/ops/_masked_softmax_backward.h>
276
+ #include <ATen/ops/_mixed_dtypes_linear.h>
277
+ #include <ATen/ops/_mkldnn_reshape.h>
278
+ #include <ATen/ops/_mkldnn_transpose.h>
279
+ #include <ATen/ops/_mps_convolution.h>
280
+ #include <ATen/ops/_mps_convolution_transpose.h>
281
+ #include <ATen/ops/_native_batch_norm_legit.h>
282
+ #include <ATen/ops/_native_batch_norm_legit_no_training.h>
283
+ #include <ATen/ops/_native_multi_head_attention.h>
284
+ #include <ATen/ops/_neg_view.h>
285
+ #include <ATen/ops/_neg_view_copy.h>
286
+ #include <ATen/ops/_nested_compute_contiguous_strides_offsets.h>
287
+ #include <ATen/ops/_nested_from_padded.h>
288
+ #include <ATen/ops/_nested_from_padded_and_nested_example.h>
289
+ #include <ATen/ops/_nested_from_padded_tensor.h>
290
+ #include <ATen/ops/_nested_get_jagged_dummy.h>
291
+ #include <ATen/ops/_nested_get_lengths.h>
292
+ #include <ATen/ops/_nested_get_max_seqlen.h>
293
+ #include <ATen/ops/_nested_get_min_seqlen.h>
294
+ #include <ATen/ops/_nested_get_offsets.h>
295
+ #include <ATen/ops/_nested_get_ragged_idx.h>
296
+ #include <ATen/ops/_nested_get_values.h>
297
+ #include <ATen/ops/_nested_get_values_copy.h>
298
+ #include <ATen/ops/_nested_select_backward.h>
299
+ #include <ATen/ops/_nested_sum_backward.h>
300
+ #include <ATen/ops/_nested_tensor_from_mask.h>
301
+ #include <ATen/ops/_nested_tensor_from_mask_left_aligned.h>
302
+ #include <ATen/ops/_nested_tensor_from_tensor_list.h>
303
+ #include <ATen/ops/_nested_tensor_size.h>
304
+ #include <ATen/ops/_nested_tensor_softmax_with_shape.h>
305
+ #include <ATen/ops/_nested_tensor_storage_offsets.h>
306
+ #include <ATen/ops/_nested_tensor_strides.h>
307
+ #include <ATen/ops/_nested_view_from_buffer.h>
308
+ #include <ATen/ops/_nested_view_from_buffer_copy.h>
309
+ #include <ATen/ops/_nested_view_from_jagged.h>
310
+ #include <ATen/ops/_nested_view_from_jagged_copy.h>
311
+ #include <ATen/ops/_new_zeros_with_same_feature_meta.h>
312
+ #include <ATen/ops/_nnpack_available.h>
313
+ #include <ATen/ops/_nnpack_spatial_convolution.h>
314
+ #include <ATen/ops/_nnz.h>
315
+ #include <ATen/ops/_pack_padded_sequence.h>
316
+ #include <ATen/ops/_pack_padded_sequence_backward.h>
317
+ #include <ATen/ops/_pad_circular.h>
318
+ #include <ATen/ops/_pad_enum.h>
319
+ #include <ATen/ops/_pad_packed_sequence.h>
320
+ #include <ATen/ops/_padded_dense_to_jagged_forward.h>
321
+ #include <ATen/ops/_pdist_backward.h>
322
+ #include <ATen/ops/_pdist_forward.h>
323
+ #include <ATen/ops/_pin_memory.h>
324
+ #include <ATen/ops/_prelu_kernel.h>
325
+ #include <ATen/ops/_prelu_kernel_backward.h>
326
+ #include <ATen/ops/_print.h>
327
+ #include <ATen/ops/_propagate_xla_data.h>
328
+ #include <ATen/ops/_remove_batch_dim.h>
329
+ #include <ATen/ops/_reshape_alias.h>
330
+ #include <ATen/ops/_reshape_alias_copy.h>
331
+ #include <ATen/ops/_reshape_copy.h>
332
+ #include <ATen/ops/_reshape_from_tensor.h>
333
+ #include <ATen/ops/_resize_output.h>
334
+ #include <ATen/ops/_rowwise_prune.h>
335
+ #include <ATen/ops/_safe_softmax.h>
336
+ #include <ATen/ops/_sample_dirichlet.h>
337
+ #include <ATen/ops/_saturate_weight_to_fp16.h>
338
+ #include <ATen/ops/_scaled_dot_product_attention_math.h>
339
+ #include <ATen/ops/_scaled_dot_product_attention_math_for_mps.h>
340
+ #include <ATen/ops/_scaled_dot_product_cudnn_attention.h>
341
+ #include <ATen/ops/_scaled_dot_product_cudnn_attention_backward.h>
342
+ #include <ATen/ops/_scaled_dot_product_efficient_attention.h>
343
+ #include <ATen/ops/_scaled_dot_product_efficient_attention_backward.h>
344
+ #include <ATen/ops/_scaled_dot_product_flash_attention.h>
345
+ #include <ATen/ops/_scaled_dot_product_flash_attention_backward.h>
346
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu.h>
347
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h>
348
+ #include <ATen/ops/_scaled_dot_product_fused_attention_overrideable.h>
349
+ #include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward.h>
350
+ #include <ATen/ops/_scaled_grouped_mm.h>
351
+ #include <ATen/ops/_scaled_grouped_mm_v2.h>
352
+ #include <ATen/ops/_scaled_mm.h>
353
+ #include <ATen/ops/_scaled_mm_v2.h>
354
+ #include <ATen/ops/_segment_reduce_backward.h>
355
+ #include <ATen/ops/_shape_as_tensor.h>
356
+ #include <ATen/ops/_slow_conv2d_backward.h>
357
+ #include <ATen/ops/_slow_conv2d_forward.h>
358
+ #include <ATen/ops/_sobol_engine_draw.h>
359
+ #include <ATen/ops/_sobol_engine_ff.h>
360
+ #include <ATen/ops/_sobol_engine_initialize_state.h>
361
+ #include <ATen/ops/_sobol_engine_scramble.h>
362
+ #include <ATen/ops/_softmax.h>
363
+ #include <ATen/ops/_softmax_backward_data.h>
364
+ #include <ATen/ops/_sparse_addmm.h>
365
+ #include <ATen/ops/_sparse_broadcast_to.h>
366
+ #include <ATen/ops/_sparse_broadcast_to_copy.h>
367
+ #include <ATen/ops/_sparse_bsc_tensor_unsafe.h>
368
+ #include <ATen/ops/_sparse_bsr_tensor_unsafe.h>
369
+ #include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
370
+ #include <ATen/ops/_sparse_compressed_tensor_with_dims.h>
371
+ #include <ATen/ops/_sparse_coo_tensor_unsafe.h>
372
+ #include <ATen/ops/_sparse_coo_tensor_with_dims.h>
373
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
374
+ #include <ATen/ops/_sparse_csc_tensor_unsafe.h>
375
+ #include <ATen/ops/_sparse_csr_prod.h>
376
+ #include <ATen/ops/_sparse_csr_sum.h>
377
+ #include <ATen/ops/_sparse_csr_tensor_unsafe.h>
378
+ #include <ATen/ops/_sparse_log_softmax.h>
379
+ #include <ATen/ops/_sparse_log_softmax_backward_data.h>
380
+ #include <ATen/ops/_sparse_mask_projection.h>
381
+ #include <ATen/ops/_sparse_mm.h>
382
+ #include <ATen/ops/_sparse_mm_reduce_impl.h>
383
+ #include <ATen/ops/_sparse_mm_reduce_impl_backward.h>
384
+ #include <ATen/ops/_sparse_semi_structured_addmm.h>
385
+ #include <ATen/ops/_sparse_semi_structured_apply.h>
386
+ #include <ATen/ops/_sparse_semi_structured_apply_dense.h>
387
+ #include <ATen/ops/_sparse_semi_structured_linear.h>
388
+ #include <ATen/ops/_sparse_semi_structured_mm.h>
389
+ #include <ATen/ops/_sparse_semi_structured_tile.h>
390
+ #include <ATen/ops/_sparse_softmax.h>
391
+ #include <ATen/ops/_sparse_softmax_backward_data.h>
392
+ #include <ATen/ops/_sparse_sparse_matmul.h>
393
+ #include <ATen/ops/_sparse_sum.h>
394
+ #include <ATen/ops/_sparse_sum_backward.h>
395
+ #include <ATen/ops/_spdiags.h>
396
+ #include <ATen/ops/_spsolve.h>
397
+ #include <ATen/ops/_stack.h>
398
+ #include <ATen/ops/_standard_gamma.h>
399
+ #include <ATen/ops/_standard_gamma_grad.h>
400
+ #include <ATen/ops/_test_ambiguous_defaults.h>
401
+ #include <ATen/ops/_test_autograd_multiple_dispatch.h>
402
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view.h>
403
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_copy.h>
404
+ #include <ATen/ops/_test_check_tensor.h>
405
+ #include <ATen/ops/_test_functorch_fallback.h>
406
+ #include <ATen/ops/_test_optional_filled_intlist.h>
407
+ #include <ATen/ops/_test_optional_floatlist.h>
408
+ #include <ATen/ops/_test_optional_intlist.h>
409
+ #include <ATen/ops/_test_parallel_materialize.h>
410
+ #include <ATen/ops/_test_serialization_subcmul.h>
411
+ #include <ATen/ops/_test_string_default.h>
412
+ #include <ATen/ops/_test_warn_in_autograd.h>
413
+ #include <ATen/ops/_thnn_differentiable_gru_cell_backward.h>
414
+ #include <ATen/ops/_thnn_differentiable_lstm_cell_backward.h>
415
+ #include <ATen/ops/_thnn_fused_gru_cell.h>
416
+ #include <ATen/ops/_thnn_fused_gru_cell_backward.h>
417
+ #include <ATen/ops/_thnn_fused_lstm_cell.h>
418
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward.h>
419
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward_impl.h>
420
+ #include <ATen/ops/_to_copy.h>
421
+ #include <ATen/ops/_to_cpu.h>
422
+ #include <ATen/ops/_to_dense.h>
423
+ #include <ATen/ops/_to_sparse.h>
424
+ #include <ATen/ops/_to_sparse_bsc.h>
425
+ #include <ATen/ops/_to_sparse_bsr.h>
426
+ #include <ATen/ops/_to_sparse_csc.h>
427
+ #include <ATen/ops/_to_sparse_csr.h>
428
+ #include <ATen/ops/_to_sparse_semi_structured.h>
429
+ #include <ATen/ops/_transform_bias_rescale_qkv.h>
430
+ #include <ATen/ops/_transformer_encoder_layer_fwd.h>
431
+ #include <ATen/ops/_trilinear.h>
432
+ #include <ATen/ops/_triton_multi_head_attention.h>
433
+ #include <ATen/ops/_triton_scaled_dot_attention.h>
434
+ #include <ATen/ops/_unique.h>
435
+ #include <ATen/ops/_unique2.h>
436
+ #include <ATen/ops/_unpack_dual.h>
437
+ #include <ATen/ops/_unsafe_index.h>
438
+ #include <ATen/ops/_unsafe_index_put.h>
439
+ #include <ATen/ops/_unsafe_masked_index.h>
440
+ #include <ATen/ops/_unsafe_masked_index_put_accumulate.h>
441
+ #include <ATen/ops/_unsafe_view.h>
442
+ #include <ATen/ops/_upsample_bicubic2d_aa.h>
443
+ #include <ATen/ops/_upsample_bicubic2d_aa_backward.h>
444
+ #include <ATen/ops/_upsample_bilinear2d_aa.h>
445
+ #include <ATen/ops/_upsample_bilinear2d_aa_backward.h>
446
+ #include <ATen/ops/_upsample_nearest_exact1d.h>
447
+ #include <ATen/ops/_upsample_nearest_exact1d_backward.h>
448
+ #include <ATen/ops/_upsample_nearest_exact2d.h>
449
+ #include <ATen/ops/_upsample_nearest_exact2d_backward.h>
450
+ #include <ATen/ops/_upsample_nearest_exact3d.h>
451
+ #include <ATen/ops/_upsample_nearest_exact3d_backward.h>
452
+ #include <ATen/ops/_use_cudnn_ctc_loss.h>
453
+ #include <ATen/ops/_use_cudnn_rnn_flatten_weight.h>
454
+ #include <ATen/ops/_validate_compressed_sparse_indices.h>
455
+ #include <ATen/ops/_validate_sparse_bsc_tensor_args.h>
456
+ #include <ATen/ops/_validate_sparse_bsr_tensor_args.h>
457
+ #include <ATen/ops/_validate_sparse_compressed_tensor_args.h>
458
+ #include <ATen/ops/_validate_sparse_coo_tensor_args.h>
459
+ #include <ATen/ops/_validate_sparse_csc_tensor_args.h>
460
+ #include <ATen/ops/_validate_sparse_csr_tensor_args.h>
461
+ #include <ATen/ops/_values.h>
462
+ #include <ATen/ops/_values_copy.h>
463
+ #include <ATen/ops/_version.h>
464
+ #include <ATen/ops/_weight_int4pack_mm.h>
465
+ #include <ATen/ops/_weight_int4pack_mm_for_cpu.h>
466
+ #include <ATen/ops/_weight_int4pack_mm_with_scales_and_zeros.h>
467
+ #include <ATen/ops/_weight_int8pack_mm.h>
468
+ #include <ATen/ops/_weight_norm.h>
469
+ #include <ATen/ops/_weight_norm_differentiable_backward.h>
470
+ #include <ATen/ops/_weight_norm_interface.h>
471
+ #include <ATen/ops/_weight_norm_interface_backward.h>
472
+ #include <ATen/ops/_wrapped_linear_prepack.h>
473
+ #include <ATen/ops/_wrapped_quantized_linear_prepacked.h>
474
+ #include <ATen/ops/abs.h>
475
+ #include <ATen/ops/absolute.h>
476
+ #include <ATen/ops/acos.h>
477
+ #include <ATen/ops/acosh.h>
478
+ #include <ATen/ops/adaptive_avg_pool1d.h>
479
+ #include <ATen/ops/adaptive_avg_pool2d.h>
480
+ #include <ATen/ops/adaptive_avg_pool3d.h>
481
+ #include <ATen/ops/adaptive_avg_pool3d_backward.h>
482
+ #include <ATen/ops/adaptive_max_pool1d.h>
483
+ #include <ATen/ops/adaptive_max_pool2d.h>
484
+ #include <ATen/ops/adaptive_max_pool2d_backward.h>
485
+ #include <ATen/ops/adaptive_max_pool3d.h>
486
+ #include <ATen/ops/adaptive_max_pool3d_backward.h>
487
+ #include <ATen/ops/add.h>
488
+ #include <ATen/ops/addbmm.h>
489
+ #include <ATen/ops/addcdiv.h>
490
+ #include <ATen/ops/addcmul.h>
491
+ #include <ATen/ops/addmm.h>
492
+ #include <ATen/ops/addmv.h>
493
+ #include <ATen/ops/addr.h>
494
+ #include <ATen/ops/adjoint.h>
495
+ #include <ATen/ops/affine_grid_generator.h>
496
+ #include <ATen/ops/affine_grid_generator_backward.h>
497
+ #include <ATen/ops/alias.h>
498
+ #include <ATen/ops/alias_copy.h>
499
+ #include <ATen/ops/align_as.h>
500
+ #include <ATen/ops/align_tensors.h>
501
+ #include <ATen/ops/align_to.h>
502
+ #include <ATen/ops/all.h>
503
+ #include <ATen/ops/allclose.h>
504
+ #include <ATen/ops/alpha_dropout.h>
505
+ #include <ATen/ops/amax.h>
506
+ #include <ATen/ops/amin.h>
507
+ #include <ATen/ops/aminmax.h>
508
+ #include <ATen/ops/and.h>
509
+ #include <ATen/ops/angle.h>
510
+ #include <ATen/ops/any.h>
511
+ #include <ATen/ops/arange.h>
512
+ #include <ATen/ops/arccos.h>
513
+ #include <ATen/ops/arccosh.h>
514
+ #include <ATen/ops/arcsin.h>
515
+ #include <ATen/ops/arcsinh.h>
516
+ #include <ATen/ops/arctan.h>
517
+ #include <ATen/ops/arctan2.h>
518
+ #include <ATen/ops/arctanh.h>
519
+ #include <ATen/ops/argmax.h>
520
+ #include <ATen/ops/argmin.h>
521
+ #include <ATen/ops/argsort.h>
522
+ #include <ATen/ops/argwhere.h>
523
+ #include <ATen/ops/as_strided.h>
524
+ #include <ATen/ops/as_strided_copy.h>
525
+ #include <ATen/ops/as_strided_scatter.h>
526
+ #include <ATen/ops/asin.h>
527
+ #include <ATen/ops/asinh.h>
528
+ #include <ATen/ops/atan.h>
529
+ #include <ATen/ops/atan2.h>
530
+ #include <ATen/ops/atanh.h>
531
+ #include <ATen/ops/atleast_1d.h>
532
+ #include <ATen/ops/atleast_2d.h>
533
+ #include <ATen/ops/atleast_3d.h>
534
+ #include <ATen/ops/avg_pool1d.h>
535
+ #include <ATen/ops/avg_pool2d.h>
536
+ #include <ATen/ops/avg_pool2d_backward.h>
537
+ #include <ATen/ops/avg_pool3d.h>
538
+ #include <ATen/ops/avg_pool3d_backward.h>
539
+ #include <ATen/ops/baddbmm.h>
540
+ #include <ATen/ops/bartlett_window.h>
541
+ #include <ATen/ops/batch_norm.h>
542
+ #include <ATen/ops/batch_norm_backward.h>
543
+ #include <ATen/ops/batch_norm_backward_elemt.h>
544
+ #include <ATen/ops/batch_norm_backward_reduce.h>
545
+ #include <ATen/ops/batch_norm_elemt.h>
546
+ #include <ATen/ops/batch_norm_gather_stats.h>
547
+ #include <ATen/ops/batch_norm_gather_stats_with_counts.h>
548
+ #include <ATen/ops/batch_norm_stats.h>
549
+ #include <ATen/ops/batch_norm_update_stats.h>
550
+ #include <ATen/ops/bernoulli.h>
551
+ #include <ATen/ops/bilinear.h>
552
+ #include <ATen/ops/binary_cross_entropy.h>
553
+ #include <ATen/ops/binary_cross_entropy_backward.h>
554
+ #include <ATen/ops/binary_cross_entropy_with_logits.h>
555
+ #include <ATen/ops/bincount.h>
556
+ #include <ATen/ops/binomial.h>
557
+ #include <ATen/ops/bitwise_and.h>
558
+ #include <ATen/ops/bitwise_left_shift.h>
559
+ #include <ATen/ops/bitwise_not.h>
560
+ #include <ATen/ops/bitwise_or.h>
561
+ #include <ATen/ops/bitwise_right_shift.h>
562
+ #include <ATen/ops/bitwise_xor.h>
563
+ #include <ATen/ops/blackman_window.h>
564
+ #include <ATen/ops/block_diag.h>
565
+ #include <ATen/ops/bmm.h>
566
+ #include <ATen/ops/broadcast_tensors.h>
567
+ #include <ATen/ops/broadcast_to.h>
568
+ #include <ATen/ops/bucketize.h>
569
+ #include <ATen/ops/can_cast.h>
570
+ #include <ATen/ops/cartesian_prod.h>
571
+ #include <ATen/ops/cat.h>
572
+ #include <ATen/ops/cauchy.h>
573
+ #include <ATen/ops/ccol_indices.h>
574
+ #include <ATen/ops/ccol_indices_copy.h>
575
+ #include <ATen/ops/cdist.h>
576
+ #include <ATen/ops/ceil.h>
577
+ #include <ATen/ops/celu.h>
578
+ #include <ATen/ops/chain_matmul.h>
579
+ #include <ATen/ops/chalf.h>
580
+ #include <ATen/ops/channel_shuffle.h>
581
+ #include <ATen/ops/cholesky.h>
582
+ #include <ATen/ops/cholesky_inverse.h>
583
+ #include <ATen/ops/cholesky_solve.h>
584
+ #include <ATen/ops/choose_qparams_optimized.h>
585
+ #include <ATen/ops/chunk.h>
586
+ #include <ATen/ops/clamp.h>
587
+ #include <ATen/ops/clamp_max.h>
588
+ #include <ATen/ops/clamp_min.h>
589
+ #include <ATen/ops/clip.h>
590
+ #include <ATen/ops/clone.h>
591
+ #include <ATen/ops/coalesce.h>
592
+ #include <ATen/ops/col2im.h>
593
+ #include <ATen/ops/col_indices.h>
594
+ #include <ATen/ops/col_indices_copy.h>
595
+ #include <ATen/ops/column_stack.h>
596
+ #include <ATen/ops/combinations.h>
597
+ #include <ATen/ops/complex.h>
598
+ #include <ATen/ops/concat.h>
599
+ #include <ATen/ops/concatenate.h>
600
+ #include <ATen/ops/conj.h>
601
+ #include <ATen/ops/conj_physical.h>
602
+ #include <ATen/ops/constant_pad_nd.h>
603
+ #include <ATen/ops/contiguous.h>
604
+ #include <ATen/ops/conv1d.h>
605
+ #include <ATen/ops/conv2d.h>
606
+ #include <ATen/ops/conv3d.h>
607
+ #include <ATen/ops/conv_depthwise3d.h>
608
+ #include <ATen/ops/conv_tbc.h>
609
+ #include <ATen/ops/conv_tbc_backward.h>
610
+ #include <ATen/ops/conv_transpose1d.h>
611
+ #include <ATen/ops/conv_transpose2d.h>
612
+ #include <ATen/ops/conv_transpose3d.h>
613
+ #include <ATen/ops/convolution.h>
614
+ #include <ATen/ops/convolution_backward.h>
615
+ #include <ATen/ops/convolution_backward_overrideable.h>
616
+ #include <ATen/ops/convolution_overrideable.h>
617
+ #include <ATen/ops/copy.h>
618
+ #include <ATen/ops/copy_sparse_to_sparse.h>
619
+ #include <ATen/ops/copysign.h>
620
+ #include <ATen/ops/corrcoef.h>
621
+ #include <ATen/ops/cos.h>
622
+ #include <ATen/ops/cosh.h>
623
+ #include <ATen/ops/cosine_embedding_loss.h>
624
+ #include <ATen/ops/cosine_similarity.h>
625
+ #include <ATen/ops/count_nonzero.h>
626
+ #include <ATen/ops/cov.h>
627
+ #include <ATen/ops/cross.h>
628
+ #include <ATen/ops/cross_entropy_loss.h>
629
+ #include <ATen/ops/crow_indices.h>
630
+ #include <ATen/ops/crow_indices_copy.h>
631
+ #include <ATen/ops/ctc_loss.h>
632
+ #include <ATen/ops/cudnn_affine_grid_generator.h>
633
+ #include <ATen/ops/cudnn_affine_grid_generator_backward.h>
634
+ #include <ATen/ops/cudnn_batch_norm.h>
635
+ #include <ATen/ops/cudnn_batch_norm_backward.h>
636
+ #include <ATen/ops/cudnn_convolution.h>
637
+ #include <ATen/ops/cudnn_convolution_add_relu.h>
638
+ #include <ATen/ops/cudnn_convolution_relu.h>
639
+ #include <ATen/ops/cudnn_convolution_transpose.h>
640
+ #include <ATen/ops/cudnn_grid_sampler.h>
641
+ #include <ATen/ops/cudnn_grid_sampler_backward.h>
642
+ #include <ATen/ops/cudnn_is_acceptable.h>
643
+ #include <ATen/ops/cummax.h>
644
+ #include <ATen/ops/cummaxmin_backward.h>
645
+ #include <ATen/ops/cummin.h>
646
+ #include <ATen/ops/cumprod.h>
647
+ #include <ATen/ops/cumprod_backward.h>
648
+ #include <ATen/ops/cumsum.h>
649
+ #include <ATen/ops/cumulative_trapezoid.h>
650
+ #include <ATen/ops/data.h>
651
+ #include <ATen/ops/deg2rad.h>
652
+ #include <ATen/ops/dense_dim.h>
653
+ #include <ATen/ops/dequantize.h>
654
+ #include <ATen/ops/det.h>
655
+ #include <ATen/ops/detach.h>
656
+ #include <ATen/ops/detach_copy.h>
657
+ #include <ATen/ops/diag.h>
658
+ #include <ATen/ops/diag_embed.h>
659
+ #include <ATen/ops/diagflat.h>
660
+ #include <ATen/ops/diagonal.h>
661
+ #include <ATen/ops/diagonal_backward.h>
662
+ #include <ATen/ops/diagonal_copy.h>
663
+ #include <ATen/ops/diagonal_scatter.h>
664
+ #include <ATen/ops/diff.h>
665
+ #include <ATen/ops/digamma.h>
666
+ #include <ATen/ops/dist.h>
667
+ #include <ATen/ops/div.h>
668
+ #include <ATen/ops/divide.h>
669
+ #include <ATen/ops/dot.h>
670
+ #include <ATen/ops/dropout.h>
671
+ #include <ATen/ops/dsplit.h>
672
+ #include <ATen/ops/dstack.h>
673
+ #include <ATen/ops/einsum.h>
674
+ #include <ATen/ops/elu.h>
675
+ #include <ATen/ops/elu_backward.h>
676
+ #include <ATen/ops/embedding.h>
677
+ #include <ATen/ops/embedding_backward.h>
678
+ #include <ATen/ops/embedding_bag.h>
679
+ #include <ATen/ops/embedding_dense_backward.h>
680
+ #include <ATen/ops/embedding_renorm.h>
681
+ #include <ATen/ops/embedding_sparse_backward.h>
682
+ #include <ATen/ops/empty.h>
683
+ #include <ATen/ops/empty_like.h>
684
+ #include <ATen/ops/empty_permuted.h>
685
+ #include <ATen/ops/empty_quantized.h>
686
+ #include <ATen/ops/empty_strided.h>
687
+ #include <ATen/ops/eq.h>
688
+ #include <ATen/ops/equal.h>
689
+ #include <ATen/ops/erf.h>
690
+ #include <ATen/ops/erfc.h>
691
+ #include <ATen/ops/erfinv.h>
692
+ #include <ATen/ops/exp.h>
693
+ #include <ATen/ops/exp2.h>
694
+ #include <ATen/ops/expand.h>
695
+ #include <ATen/ops/expand_as.h>
696
+ #include <ATen/ops/expand_copy.h>
697
+ #include <ATen/ops/expm1.h>
698
+ #include <ATen/ops/exponential.h>
699
+ #include <ATen/ops/eye.h>
700
+ #include <ATen/ops/fake_quantize_per_channel_affine.h>
701
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask.h>
702
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward.h>
703
+ #include <ATen/ops/fake_quantize_per_tensor_affine.h>
704
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask.h>
705
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward.h>
706
+ #include <ATen/ops/fbgemm_linear_fp16_weight.h>
707
+ #include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation.h>
708
+ #include <ATen/ops/fbgemm_linear_int8_weight.h>
709
+ #include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h>
710
+ #include <ATen/ops/fbgemm_linear_quantize_weight.h>
711
+ #include <ATen/ops/fbgemm_pack_gemm_matrix_fp16.h>
712
+ #include <ATen/ops/fbgemm_pack_quantized_matrix.h>
713
+ #include <ATen/ops/feature_alpha_dropout.h>
714
+ #include <ATen/ops/feature_dropout.h>
715
+ #include <ATen/ops/fft_fft.h>
716
+ #include <ATen/ops/fft_fft2.h>
717
+ #include <ATen/ops/fft_fftfreq.h>
718
+ #include <ATen/ops/fft_fftn.h>
719
+ #include <ATen/ops/fft_fftshift.h>
720
+ #include <ATen/ops/fft_hfft.h>
721
+ #include <ATen/ops/fft_hfft2.h>
722
+ #include <ATen/ops/fft_hfftn.h>
723
+ #include <ATen/ops/fft_ifft.h>
724
+ #include <ATen/ops/fft_ifft2.h>
725
+ #include <ATen/ops/fft_ifftn.h>
726
+ #include <ATen/ops/fft_ifftshift.h>
727
+ #include <ATen/ops/fft_ihfft.h>
728
+ #include <ATen/ops/fft_ihfft2.h>
729
+ #include <ATen/ops/fft_ihfftn.h>
730
+ #include <ATen/ops/fft_irfft.h>
731
+ #include <ATen/ops/fft_irfft2.h>
732
+ #include <ATen/ops/fft_irfftn.h>
733
+ #include <ATen/ops/fft_rfft.h>
734
+ #include <ATen/ops/fft_rfft2.h>
735
+ #include <ATen/ops/fft_rfftfreq.h>
736
+ #include <ATen/ops/fft_rfftn.h>
737
+ #include <ATen/ops/fill.h>
738
+ #include <ATen/ops/fill_diagonal.h>
739
+ #include <ATen/ops/fix.h>
740
+ #include <ATen/ops/flatten.h>
741
+ #include <ATen/ops/flatten_dense_tensors.h>
742
+ #include <ATen/ops/flip.h>
743
+ #include <ATen/ops/fliplr.h>
744
+ #include <ATen/ops/flipud.h>
745
+ #include <ATen/ops/float_power.h>
746
+ #include <ATen/ops/floor.h>
747
+ #include <ATen/ops/floor_divide.h>
748
+ #include <ATen/ops/fmax.h>
749
+ #include <ATen/ops/fmin.h>
750
+ #include <ATen/ops/fmod.h>
751
+ #include <ATen/ops/frac.h>
752
+ #include <ATen/ops/fractional_max_pool2d.h>
753
+ #include <ATen/ops/fractional_max_pool2d_backward.h>
754
+ #include <ATen/ops/fractional_max_pool3d.h>
755
+ #include <ATen/ops/fractional_max_pool3d_backward.h>
756
+ #include <ATen/ops/frexp.h>
757
+ #include <ATen/ops/frobenius_norm.h>
758
+ #include <ATen/ops/from_file.h>
759
+ #include <ATen/ops/full.h>
760
+ #include <ATen/ops/full_like.h>
761
+ #include <ATen/ops/fused_moving_avg_obs_fake_quant.h>
762
+ #include <ATen/ops/gather.h>
763
+ #include <ATen/ops/gather_backward.h>
764
+ #include <ATen/ops/gcd.h>
765
+ #include <ATen/ops/ge.h>
766
+ #include <ATen/ops/gelu.h>
767
+ #include <ATen/ops/gelu_backward.h>
768
+ #include <ATen/ops/geometric.h>
769
+ #include <ATen/ops/geqrf.h>
770
+ #include <ATen/ops/ger.h>
771
+ #include <ATen/ops/glu.h>
772
+ #include <ATen/ops/glu_backward.h>
773
+ #include <ATen/ops/glu_backward_jvp.h>
774
+ #include <ATen/ops/glu_jvp.h>
775
+ #include <ATen/ops/gradient.h>
776
+ #include <ATen/ops/greater.h>
777
+ #include <ATen/ops/greater_equal.h>
778
+ #include <ATen/ops/grid_sampler.h>
779
+ #include <ATen/ops/grid_sampler_2d.h>
780
+ #include <ATen/ops/grid_sampler_2d_backward.h>
781
+ #include <ATen/ops/grid_sampler_3d.h>
782
+ #include <ATen/ops/grid_sampler_3d_backward.h>
783
+ #include <ATen/ops/group_norm.h>
784
+ #include <ATen/ops/gru.h>
785
+ #include <ATen/ops/gru_cell.h>
786
+ #include <ATen/ops/gt.h>
787
+ #include <ATen/ops/hamming_window.h>
788
+ #include <ATen/ops/hann_window.h>
789
+ #include <ATen/ops/hardshrink.h>
790
+ #include <ATen/ops/hardshrink_backward.h>
791
+ #include <ATen/ops/hardsigmoid.h>
792
+ #include <ATen/ops/hardsigmoid_backward.h>
793
+ #include <ATen/ops/hardswish.h>
794
+ #include <ATen/ops/hardswish_backward.h>
795
+ #include <ATen/ops/hardtanh.h>
796
+ #include <ATen/ops/hardtanh_backward.h>
797
+ #include <ATen/ops/hash_tensor.h>
798
+ #include <ATen/ops/heaviside.h>
799
+ #include <ATen/ops/hinge_embedding_loss.h>
800
+ #include <ATen/ops/histc.h>
801
+ #include <ATen/ops/histogram.h>
802
+ #include <ATen/ops/histogramdd.h>
803
+ #include <ATen/ops/hsplit.h>
804
+ #include <ATen/ops/hspmm.h>
805
+ #include <ATen/ops/hstack.h>
806
+ #include <ATen/ops/huber_loss.h>
807
+ #include <ATen/ops/huber_loss_backward.h>
808
+ #include <ATen/ops/hypot.h>
809
+ #include <ATen/ops/i0.h>
810
+ #include <ATen/ops/igamma.h>
811
+ #include <ATen/ops/igammac.h>
812
+ #include <ATen/ops/im2col.h>
813
+ #include <ATen/ops/imag.h>
814
+ #include <ATen/ops/index.h>
815
+ #include <ATen/ops/index_add.h>
816
+ #include <ATen/ops/index_copy.h>
817
+ #include <ATen/ops/index_fill.h>
818
+ #include <ATen/ops/index_put.h>
819
+ #include <ATen/ops/index_reduce.h>
820
+ #include <ATen/ops/index_select.h>
821
+ #include <ATen/ops/index_select_backward.h>
822
+ #include <ATen/ops/indices.h>
823
+ #include <ATen/ops/indices_copy.h>
824
+ #include <ATen/ops/infinitely_differentiable_gelu_backward.h>
825
+ #include <ATen/ops/inner.h>
826
+ #include <ATen/ops/instance_norm.h>
827
+ #include <ATen/ops/int_repr.h>
828
+ #include <ATen/ops/inverse.h>
829
+ #include <ATen/ops/is_coalesced.h>
830
+ #include <ATen/ops/is_complex.h>
831
+ #include <ATen/ops/is_conj.h>
832
+ #include <ATen/ops/is_distributed.h>
833
+ #include <ATen/ops/is_floating_point.h>
834
+ #include <ATen/ops/is_inference.h>
835
+ #include <ATen/ops/is_leaf.h>
836
+ #include <ATen/ops/is_neg.h>
837
+ #include <ATen/ops/is_nonzero.h>
838
+ #include <ATen/ops/is_pinned.h>
839
+ #include <ATen/ops/is_same_size.h>
840
+ #include <ATen/ops/is_set_to.h>
841
+ #include <ATen/ops/is_signed.h>
842
+ #include <ATen/ops/is_vulkan_available.h>
843
+ #include <ATen/ops/isclose.h>
844
+ #include <ATen/ops/isfinite.h>
845
+ #include <ATen/ops/isin.h>
846
+ #include <ATen/ops/isinf.h>
847
+ #include <ATen/ops/isnan.h>
848
+ #include <ATen/ops/isneginf.h>
849
+ #include <ATen/ops/isposinf.h>
850
+ #include <ATen/ops/isreal.h>
851
+ #include <ATen/ops/istft.h>
852
+ #include <ATen/ops/item.h>
853
+ #include <ATen/ops/kaiser_window.h>
854
+ #include <ATen/ops/kl_div.h>
855
+ #include <ATen/ops/kron.h>
856
+ #include <ATen/ops/kthvalue.h>
857
+ #include <ATen/ops/l1_loss.h>
858
+ #include <ATen/ops/layer_norm.h>
859
+ #include <ATen/ops/lcm.h>
860
+ #include <ATen/ops/ldexp.h>
861
+ #include <ATen/ops/le.h>
862
+ #include <ATen/ops/leaky_relu.h>
863
+ #include <ATen/ops/leaky_relu_backward.h>
864
+ #include <ATen/ops/lerp.h>
865
+ #include <ATen/ops/less.h>
866
+ #include <ATen/ops/less_equal.h>
867
+ #include <ATen/ops/lgamma.h>
868
+ #include <ATen/ops/lift.h>
869
+ #include <ATen/ops/lift_fresh.h>
870
+ #include <ATen/ops/lift_fresh_copy.h>
871
+ #include <ATen/ops/linalg_cholesky.h>
872
+ #include <ATen/ops/linalg_cholesky_ex.h>
873
+ #include <ATen/ops/linalg_cond.h>
874
+ #include <ATen/ops/linalg_cross.h>
875
+ #include <ATen/ops/linalg_det.h>
876
+ #include <ATen/ops/linalg_diagonal.h>
877
+ #include <ATen/ops/linalg_eig.h>
878
+ #include <ATen/ops/linalg_eigh.h>
879
+ #include <ATen/ops/linalg_eigvals.h>
880
+ #include <ATen/ops/linalg_eigvalsh.h>
881
+ #include <ATen/ops/linalg_householder_product.h>
882
+ #include <ATen/ops/linalg_inv.h>
883
+ #include <ATen/ops/linalg_inv_ex.h>
884
+ #include <ATen/ops/linalg_ldl_factor.h>
885
+ #include <ATen/ops/linalg_ldl_factor_ex.h>
886
+ #include <ATen/ops/linalg_ldl_solve.h>
887
+ #include <ATen/ops/linalg_lstsq.h>
888
+ #include <ATen/ops/linalg_lu.h>
889
+ #include <ATen/ops/linalg_lu_factor.h>
890
+ #include <ATen/ops/linalg_lu_factor_ex.h>
891
+ #include <ATen/ops/linalg_lu_solve.h>
892
+ #include <ATen/ops/linalg_matmul.h>
893
+ #include <ATen/ops/linalg_matrix_exp.h>
894
+ #include <ATen/ops/linalg_matrix_norm.h>
895
+ #include <ATen/ops/linalg_matrix_power.h>
896
+ #include <ATen/ops/linalg_matrix_rank.h>
897
+ #include <ATen/ops/linalg_multi_dot.h>
898
+ #include <ATen/ops/linalg_norm.h>
899
+ #include <ATen/ops/linalg_pinv.h>
900
+ #include <ATen/ops/linalg_qr.h>
901
+ #include <ATen/ops/linalg_slogdet.h>
902
+ #include <ATen/ops/linalg_solve.h>
903
+ #include <ATen/ops/linalg_solve_ex.h>
904
+ #include <ATen/ops/linalg_solve_triangular.h>
905
+ #include <ATen/ops/linalg_svd.h>
906
+ #include <ATen/ops/linalg_svdvals.h>
907
+ #include <ATen/ops/linalg_tensorinv.h>
908
+ #include <ATen/ops/linalg_tensorsolve.h>
909
+ #include <ATen/ops/linalg_vander.h>
910
+ #include <ATen/ops/linalg_vecdot.h>
911
+ #include <ATen/ops/linalg_vector_norm.h>
912
+ #include <ATen/ops/linear.h>
913
+ #include <ATen/ops/linear_backward.h>
914
+ #include <ATen/ops/linspace.h>
915
+ #include <ATen/ops/log.h>
916
+ #include <ATen/ops/log10.h>
917
+ #include <ATen/ops/log1p.h>
918
+ #include <ATen/ops/log2.h>
919
+ #include <ATen/ops/log_normal.h>
920
+ #include <ATen/ops/log_sigmoid.h>
921
+ #include <ATen/ops/log_sigmoid_backward.h>
922
+ #include <ATen/ops/log_sigmoid_forward.h>
923
+ #include <ATen/ops/log_softmax.h>
924
+ #include <ATen/ops/logaddexp.h>
925
+ #include <ATen/ops/logaddexp2.h>
926
+ #include <ATen/ops/logcumsumexp.h>
927
+ #include <ATen/ops/logdet.h>
928
+ #include <ATen/ops/logical_and.h>
929
+ #include <ATen/ops/logical_not.h>
930
+ #include <ATen/ops/logical_or.h>
931
+ #include <ATen/ops/logical_xor.h>
932
+ #include <ATen/ops/logit.h>
933
+ #include <ATen/ops/logit_backward.h>
934
+ #include <ATen/ops/logspace.h>
935
+ #include <ATen/ops/logsumexp.h>
936
+ #include <ATen/ops/lshift.h>
937
+ #include <ATen/ops/lstm.h>
938
+ #include <ATen/ops/lstm_cell.h>
939
+ #include <ATen/ops/lstm_mps_backward.h>
940
+ #include <ATen/ops/lt.h>
941
+ #include <ATen/ops/lu_solve.h>
942
+ #include <ATen/ops/lu_unpack.h>
943
+ #include <ATen/ops/mH.h>
944
+ #include <ATen/ops/mT.h>
945
+ #include <ATen/ops/margin_ranking_loss.h>
946
+ #include <ATen/ops/masked_fill.h>
947
+ #include <ATen/ops/masked_scatter.h>
948
+ #include <ATen/ops/masked_scatter_backward.h>
949
+ #include <ATen/ops/masked_select.h>
950
+ #include <ATen/ops/masked_select_backward.h>
951
+ #include <ATen/ops/matmul.h>
952
+ #include <ATen/ops/matmul_backward.h>
953
+ #include <ATen/ops/matrix_H.h>
954
+ #include <ATen/ops/matrix_exp.h>
955
+ #include <ATen/ops/matrix_exp_backward.h>
956
+ #include <ATen/ops/matrix_power.h>
957
+ #include <ATen/ops/max.h>
958
+ #include <ATen/ops/max_pool1d.h>
959
+ #include <ATen/ops/max_pool1d_with_indices.h>
960
+ #include <ATen/ops/max_pool2d.h>
961
+ #include <ATen/ops/max_pool2d_backward.h>
962
+ #include <ATen/ops/max_pool2d_with_indices.h>
963
+ #include <ATen/ops/max_pool2d_with_indices_backward.h>
964
+ #include <ATen/ops/max_pool3d.h>
965
+ #include <ATen/ops/max_pool3d_with_indices.h>
966
+ #include <ATen/ops/max_pool3d_with_indices_backward.h>
967
+ #include <ATen/ops/max_unpool2d.h>
968
+ #include <ATen/ops/max_unpool3d.h>
969
+ #include <ATen/ops/maximum.h>
970
+ #include <ATen/ops/mean.h>
971
+ #include <ATen/ops/median.h>
972
+ #include <ATen/ops/meshgrid.h>
973
+ #include <ATen/ops/min.h>
974
+ #include <ATen/ops/minimum.h>
975
+ #include <ATen/ops/miopen_batch_norm.h>
976
+ #include <ATen/ops/miopen_batch_norm_backward.h>
977
+ #include <ATen/ops/miopen_convolution.h>
978
+ #include <ATen/ops/miopen_convolution_add_relu.h>
979
+ #include <ATen/ops/miopen_convolution_relu.h>
980
+ #include <ATen/ops/miopen_convolution_transpose.h>
981
+ #include <ATen/ops/miopen_depthwise_convolution.h>
982
+ #include <ATen/ops/miopen_rnn.h>
983
+ #include <ATen/ops/miopen_rnn_backward.h>
984
+ #include <ATen/ops/mish.h>
985
+ #include <ATen/ops/mish_backward.h>
986
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d.h>
987
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward.h>
988
+ #include <ATen/ops/mkldnn_convolution.h>
989
+ #include <ATen/ops/mkldnn_linear.h>
990
+ #include <ATen/ops/mkldnn_linear_backward.h>
991
+ #include <ATen/ops/mkldnn_linear_backward_input.h>
992
+ #include <ATen/ops/mkldnn_linear_backward_weights.h>
993
+ #include <ATen/ops/mkldnn_max_pool2d.h>
994
+ #include <ATen/ops/mkldnn_max_pool2d_backward.h>
995
+ #include <ATen/ops/mkldnn_max_pool3d.h>
996
+ #include <ATen/ops/mkldnn_max_pool3d_backward.h>
997
+ #include <ATen/ops/mkldnn_reorder_conv2d_weight.h>
998
+ #include <ATen/ops/mkldnn_reorder_conv3d_weight.h>
999
+ #include <ATen/ops/mkldnn_rnn_layer.h>
1000
+ #include <ATen/ops/mkldnn_rnn_layer_backward.h>
1001
+ #include <ATen/ops/mm.h>
1002
+ #include <ATen/ops/mode.h>
1003
+ #include <ATen/ops/moveaxis.h>
1004
+ #include <ATen/ops/movedim.h>
1005
+ #include <ATen/ops/mps_convolution_backward.h>
1006
+ #include <ATen/ops/mps_convolution_transpose_backward.h>
1007
+ #include <ATen/ops/mse_loss.h>
1008
+ #include <ATen/ops/mse_loss_backward.h>
1009
+ #include <ATen/ops/msort.h>
1010
+ #include <ATen/ops/mul.h>
1011
+ #include <ATen/ops/multi_margin_loss.h>
1012
+ #include <ATen/ops/multi_margin_loss_backward.h>
1013
+ #include <ATen/ops/multilabel_margin_loss.h>
1014
+ #include <ATen/ops/multilabel_margin_loss_backward.h>
1015
+ #include <ATen/ops/multilabel_margin_loss_forward.h>
1016
+ #include <ATen/ops/multinomial.h>
1017
+ #include <ATen/ops/multiply.h>
1018
+ #include <ATen/ops/mv.h>
1019
+ #include <ATen/ops/mvlgamma.h>
1020
+ #include <ATen/ops/nan_to_num.h>
1021
+ #include <ATen/ops/nanmean.h>
1022
+ #include <ATen/ops/nanmedian.h>
1023
+ #include <ATen/ops/nanquantile.h>
1024
+ #include <ATen/ops/nansum.h>
1025
+ #include <ATen/ops/narrow.h>
1026
+ #include <ATen/ops/narrow_copy.h>
1027
+ #include <ATen/ops/native_batch_norm.h>
1028
+ #include <ATen/ops/native_batch_norm_backward.h>
1029
+ #include <ATen/ops/native_channel_shuffle.h>
1030
+ #include <ATen/ops/native_dropout.h>
1031
+ #include <ATen/ops/native_dropout_backward.h>
1032
+ #include <ATen/ops/native_group_norm.h>
1033
+ #include <ATen/ops/native_group_norm_backward.h>
1034
+ #include <ATen/ops/native_layer_norm.h>
1035
+ #include <ATen/ops/native_layer_norm_backward.h>
1036
+ #include <ATen/ops/native_norm.h>
1037
+ #include <ATen/ops/ne.h>
1038
+ #include <ATen/ops/neg.h>
1039
+ #include <ATen/ops/negative.h>
1040
+ #include <ATen/ops/nested_to_padded_tensor.h>
1041
+ #include <ATen/ops/new_empty.h>
1042
+ #include <ATen/ops/new_empty_strided.h>
1043
+ #include <ATen/ops/new_full.h>
1044
+ #include <ATen/ops/new_ones.h>
1045
+ #include <ATen/ops/new_zeros.h>
1046
+ #include <ATen/ops/nextafter.h>
1047
+ #include <ATen/ops/nll_loss.h>
1048
+ #include <ATen/ops/nll_loss2d.h>
1049
+ #include <ATen/ops/nll_loss2d_backward.h>
1050
+ #include <ATen/ops/nll_loss2d_forward.h>
1051
+ #include <ATen/ops/nll_loss_backward.h>
1052
+ #include <ATen/ops/nll_loss_forward.h>
1053
+ #include <ATen/ops/nll_loss_nd.h>
1054
+ #include <ATen/ops/nonzero.h>
1055
+ #include <ATen/ops/nonzero_numpy.h>
1056
+ #include <ATen/ops/nonzero_static.h>
1057
+ #include <ATen/ops/norm.h>
1058
+ #include <ATen/ops/norm_except_dim.h>
1059
+ #include <ATen/ops/normal.h>
1060
+ #include <ATen/ops/not_equal.h>
1061
+ #include <ATen/ops/nuclear_norm.h>
1062
+ #include <ATen/ops/numpy_T.h>
1063
+ #include <ATen/ops/one_hot.h>
1064
+ #include <ATen/ops/ones.h>
1065
+ #include <ATen/ops/ones_like.h>
1066
+ #include <ATen/ops/or.h>
1067
+ #include <ATen/ops/orgqr.h>
1068
+ #include <ATen/ops/ormqr.h>
1069
+ #include <ATen/ops/outer.h>
1070
+ #include <ATen/ops/output_nr.h>
1071
+ #include <ATen/ops/pad.h>
1072
+ #include <ATen/ops/pad_sequence.h>
1073
+ #include <ATen/ops/pairwise_distance.h>
1074
+ #include <ATen/ops/pdist.h>
1075
+ #include <ATen/ops/permute.h>
1076
+ #include <ATen/ops/permute_copy.h>
1077
+ #include <ATen/ops/pin_memory.h>
1078
+ #include <ATen/ops/pinverse.h>
1079
+ #include <ATen/ops/pixel_shuffle.h>
1080
+ #include <ATen/ops/pixel_unshuffle.h>
1081
+ #include <ATen/ops/poisson.h>
1082
+ #include <ATen/ops/poisson_nll_loss.h>
1083
+ #include <ATen/ops/polar.h>
1084
+ #include <ATen/ops/polygamma.h>
1085
+ #include <ATen/ops/positive.h>
1086
+ #include <ATen/ops/pow.h>
1087
+ #include <ATen/ops/prelu.h>
1088
+ #include <ATen/ops/prod.h>
1089
+ #include <ATen/ops/promote_types.h>
1090
+ #include <ATen/ops/put.h>
1091
+ #include <ATen/ops/q_per_channel_axis.h>
1092
+ #include <ATen/ops/q_per_channel_scales.h>
1093
+ #include <ATen/ops/q_per_channel_zero_points.h>
1094
+ #include <ATen/ops/q_scale.h>
1095
+ #include <ATen/ops/q_zero_point.h>
1096
+ #include <ATen/ops/qr.h>
1097
+ #include <ATen/ops/qscheme.h>
1098
+ #include <ATen/ops/quantile.h>
1099
+ #include <ATen/ops/quantize_per_channel.h>
1100
+ #include <ATen/ops/quantize_per_tensor.h>
1101
+ #include <ATen/ops/quantize_per_tensor_dynamic.h>
1102
+ #include <ATen/ops/quantized_batch_norm.h>
1103
+ #include <ATen/ops/quantized_gru_cell.h>
1104
+ #include <ATen/ops/quantized_lstm_cell.h>
1105
+ #include <ATen/ops/quantized_max_pool1d.h>
1106
+ #include <ATen/ops/quantized_max_pool2d.h>
1107
+ #include <ATen/ops/quantized_max_pool3d.h>
1108
+ #include <ATen/ops/quantized_rnn_relu_cell.h>
1109
+ #include <ATen/ops/quantized_rnn_tanh_cell.h>
1110
+ #include <ATen/ops/rad2deg.h>
1111
+ #include <ATen/ops/rand.h>
1112
+ #include <ATen/ops/rand_like.h>
1113
+ #include <ATen/ops/randint.h>
1114
+ #include <ATen/ops/randint_like.h>
1115
+ #include <ATen/ops/randn.h>
1116
+ #include <ATen/ops/randn_like.h>
1117
+ #include <ATen/ops/random.h>
1118
+ #include <ATen/ops/randperm.h>
1119
+ #include <ATen/ops/range.h>
1120
+ #include <ATen/ops/ravel.h>
1121
+ #include <ATen/ops/real.h>
1122
+ #include <ATen/ops/reciprocal.h>
1123
+ #include <ATen/ops/record_stream.h>
1124
+ #include <ATen/ops/refine_names.h>
1125
+ #include <ATen/ops/reflection_pad1d.h>
1126
+ #include <ATen/ops/reflection_pad1d_backward.h>
1127
+ #include <ATen/ops/reflection_pad2d.h>
1128
+ #include <ATen/ops/reflection_pad2d_backward.h>
1129
+ #include <ATen/ops/reflection_pad3d.h>
1130
+ #include <ATen/ops/reflection_pad3d_backward.h>
1131
+ #include <ATen/ops/relu.h>
1132
+ #include <ATen/ops/relu6.h>
1133
+ #include <ATen/ops/remainder.h>
1134
+ #include <ATen/ops/rename.h>
1135
+ #include <ATen/ops/renorm.h>
1136
+ #include <ATen/ops/repeat.h>
1137
+ #include <ATen/ops/repeat_interleave.h>
1138
+ #include <ATen/ops/replication_pad1d.h>
1139
+ #include <ATen/ops/replication_pad1d_backward.h>
1140
+ #include <ATen/ops/replication_pad2d.h>
1141
+ #include <ATen/ops/replication_pad2d_backward.h>
1142
+ #include <ATen/ops/replication_pad3d.h>
1143
+ #include <ATen/ops/replication_pad3d_backward.h>
1144
+ #include <ATen/ops/requires_grad.h>
1145
+ #include <ATen/ops/reshape.h>
1146
+ #include <ATen/ops/reshape_as.h>
1147
+ #include <ATen/ops/resize.h>
1148
+ #include <ATen/ops/resize_as.h>
1149
+ #include <ATen/ops/resize_as_sparse.h>
1150
+ #include <ATen/ops/resolve_conj.h>
1151
+ #include <ATen/ops/resolve_neg.h>
1152
+ #include <ATen/ops/result_type.h>
1153
+ #include <ATen/ops/retain_grad.h>
1154
+ #include <ATen/ops/retains_grad.h>
1155
+ #include <ATen/ops/rms_norm.h>
1156
+ #include <ATen/ops/rnn_relu.h>
1157
+ #include <ATen/ops/rnn_relu_cell.h>
1158
+ #include <ATen/ops/rnn_tanh.h>
1159
+ #include <ATen/ops/rnn_tanh_cell.h>
1160
+ #include <ATen/ops/roll.h>
1161
+ #include <ATen/ops/rot90.h>
1162
+ #include <ATen/ops/round.h>
1163
+ #include <ATen/ops/row_indices.h>
1164
+ #include <ATen/ops/row_indices_copy.h>
1165
+ #include <ATen/ops/row_stack.h>
1166
+ #include <ATen/ops/rrelu.h>
1167
+ #include <ATen/ops/rrelu_with_noise.h>
1168
+ #include <ATen/ops/rrelu_with_noise_backward.h>
1169
+ #include <ATen/ops/rshift.h>
1170
+ #include <ATen/ops/rsqrt.h>
1171
+ #include <ATen/ops/rsub.h>
1172
+ #include <ATen/ops/scalar_tensor.h>
1173
+ #include <ATen/ops/scaled_dot_product_attention.h>
1174
+ #include <ATen/ops/scatter.h>
1175
+ #include <ATen/ops/scatter_add.h>
1176
+ #include <ATen/ops/scatter_reduce.h>
1177
+ #include <ATen/ops/searchsorted.h>
1178
+ #include <ATen/ops/segment_reduce.h>
1179
+ #include <ATen/ops/select.h>
1180
+ #include <ATen/ops/select_backward.h>
1181
+ #include <ATen/ops/select_copy.h>
1182
+ #include <ATen/ops/select_scatter.h>
1183
+ #include <ATen/ops/selu.h>
1184
+ #include <ATen/ops/set.h>
1185
+ #include <ATen/ops/set_data.h>
1186
+ #include <ATen/ops/sgn.h>
1187
+ #include <ATen/ops/sigmoid.h>
1188
+ #include <ATen/ops/sigmoid_backward.h>
1189
+ #include <ATen/ops/sign.h>
1190
+ #include <ATen/ops/signbit.h>
1191
+ #include <ATen/ops/silu.h>
1192
+ #include <ATen/ops/silu_backward.h>
1193
+ #include <ATen/ops/sin.h>
1194
+ #include <ATen/ops/sinc.h>
1195
+ #include <ATen/ops/sinh.h>
1196
+ #include <ATen/ops/size.h>
1197
+ #include <ATen/ops/slice.h>
1198
+ #include <ATen/ops/slice_backward.h>
1199
+ #include <ATen/ops/slice_copy.h>
1200
+ #include <ATen/ops/slice_inverse.h>
1201
+ #include <ATen/ops/slice_scatter.h>
1202
+ #include <ATen/ops/slogdet.h>
1203
+ #include <ATen/ops/slow_conv3d.h>
1204
+ #include <ATen/ops/slow_conv3d_forward.h>
1205
+ #include <ATen/ops/slow_conv_dilated2d.h>
1206
+ #include <ATen/ops/slow_conv_dilated3d.h>
1207
+ #include <ATen/ops/slow_conv_transpose2d.h>
1208
+ #include <ATen/ops/slow_conv_transpose3d.h>
1209
+ #include <ATen/ops/smm.h>
1210
+ #include <ATen/ops/smooth_l1_loss.h>
1211
+ #include <ATen/ops/smooth_l1_loss_backward.h>
1212
+ #include <ATen/ops/soft_margin_loss.h>
1213
+ #include <ATen/ops/soft_margin_loss_backward.h>
1214
+ #include <ATen/ops/softmax.h>
1215
+ #include <ATen/ops/softplus.h>
1216
+ #include <ATen/ops/softplus_backward.h>
1217
+ #include <ATen/ops/softshrink.h>
1218
+ #include <ATen/ops/softshrink_backward.h>
1219
+ #include <ATen/ops/sort.h>
1220
+ #include <ATen/ops/sparse_bsc_tensor.h>
1221
+ #include <ATen/ops/sparse_bsr_tensor.h>
1222
+ #include <ATen/ops/sparse_compressed_tensor.h>
1223
+ #include <ATen/ops/sparse_coo_tensor.h>
1224
+ #include <ATen/ops/sparse_csc_tensor.h>
1225
+ #include <ATen/ops/sparse_csr_tensor.h>
1226
+ #include <ATen/ops/sparse_dim.h>
1227
+ #include <ATen/ops/sparse_mask.h>
1228
+ #include <ATen/ops/sparse_resize.h>
1229
+ #include <ATen/ops/sparse_resize_and_clear.h>
1230
+ #include <ATen/ops/sparse_sampled_addmm.h>
1231
+ #include <ATen/ops/special_airy_ai.h>
1232
+ #include <ATen/ops/special_bessel_j0.h>
1233
+ #include <ATen/ops/special_bessel_j1.h>
1234
+ #include <ATen/ops/special_bessel_y0.h>
1235
+ #include <ATen/ops/special_bessel_y1.h>
1236
+ #include <ATen/ops/special_chebyshev_polynomial_t.h>
1237
+ #include <ATen/ops/special_chebyshev_polynomial_u.h>
1238
+ #include <ATen/ops/special_chebyshev_polynomial_v.h>
1239
+ #include <ATen/ops/special_chebyshev_polynomial_w.h>
1240
+ #include <ATen/ops/special_digamma.h>
1241
+ #include <ATen/ops/special_entr.h>
1242
+ #include <ATen/ops/special_erf.h>
1243
+ #include <ATen/ops/special_erfc.h>
1244
+ #include <ATen/ops/special_erfcx.h>
1245
+ #include <ATen/ops/special_erfinv.h>
1246
+ #include <ATen/ops/special_exp2.h>
1247
+ #include <ATen/ops/special_expit.h>
1248
+ #include <ATen/ops/special_expm1.h>
1249
+ #include <ATen/ops/special_gammainc.h>
1250
+ #include <ATen/ops/special_gammaincc.h>
1251
+ #include <ATen/ops/special_gammaln.h>
1252
+ #include <ATen/ops/special_hermite_polynomial_h.h>
1253
+ #include <ATen/ops/special_hermite_polynomial_he.h>
1254
+ #include <ATen/ops/special_i0.h>
1255
+ #include <ATen/ops/special_i0e.h>
1256
+ #include <ATen/ops/special_i1.h>
1257
+ #include <ATen/ops/special_i1e.h>
1258
+ #include <ATen/ops/special_laguerre_polynomial_l.h>
1259
+ #include <ATen/ops/special_legendre_polynomial_p.h>
1260
+ #include <ATen/ops/special_log1p.h>
1261
+ #include <ATen/ops/special_log_ndtr.h>
1262
+ #include <ATen/ops/special_log_softmax.h>
1263
+ #include <ATen/ops/special_logit.h>
1264
+ #include <ATen/ops/special_logsumexp.h>
1265
+ #include <ATen/ops/special_modified_bessel_i0.h>
1266
+ #include <ATen/ops/special_modified_bessel_i1.h>
1267
+ #include <ATen/ops/special_modified_bessel_k0.h>
1268
+ #include <ATen/ops/special_modified_bessel_k1.h>
1269
+ #include <ATen/ops/special_multigammaln.h>
1270
+ #include <ATen/ops/special_ndtr.h>
1271
+ #include <ATen/ops/special_ndtri.h>
1272
+ #include <ATen/ops/special_polygamma.h>
1273
+ #include <ATen/ops/special_psi.h>
1274
+ #include <ATen/ops/special_round.h>
1275
+ #include <ATen/ops/special_scaled_modified_bessel_k0.h>
1276
+ #include <ATen/ops/special_scaled_modified_bessel_k1.h>
1277
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_t.h>
1278
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_u.h>
1279
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_v.h>
1280
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_w.h>
1281
+ #include <ATen/ops/special_sinc.h>
1282
+ #include <ATen/ops/special_softmax.h>
1283
+ #include <ATen/ops/special_spherical_bessel_j0.h>
1284
+ #include <ATen/ops/special_xlog1py.h>
1285
+ #include <ATen/ops/special_xlogy.h>
1286
+ #include <ATen/ops/special_zeta.h>
1287
+ #include <ATen/ops/split.h>
1288
+ #include <ATen/ops/split_copy.h>
1289
+ #include <ATen/ops/split_with_sizes.h>
1290
+ #include <ATen/ops/split_with_sizes_copy.h>
1291
+ #include <ATen/ops/sqrt.h>
1292
+ #include <ATen/ops/square.h>
1293
+ #include <ATen/ops/squeeze.h>
1294
+ #include <ATen/ops/squeeze_copy.h>
1295
+ #include <ATen/ops/sspaddmm.h>
1296
+ #include <ATen/ops/stack.h>
1297
+ #include <ATen/ops/std.h>
1298
+ #include <ATen/ops/std_mean.h>
1299
+ #include <ATen/ops/stft.h>
1300
+ #include <ATen/ops/stride.h>
1301
+ #include <ATen/ops/sub.h>
1302
+ #include <ATen/ops/subtract.h>
1303
+ #include <ATen/ops/sum.h>
1304
+ #include <ATen/ops/sum_to_size.h>
1305
+ #include <ATen/ops/svd.h>
1306
+ #include <ATen/ops/swapaxes.h>
1307
+ #include <ATen/ops/swapdims.h>
1308
+ #include <ATen/ops/sym_constrain_range.h>
1309
+ #include <ATen/ops/sym_constrain_range_for_size.h>
1310
+ #include <ATen/ops/sym_is_contiguous.h>
1311
+ #include <ATen/ops/sym_numel.h>
1312
+ #include <ATen/ops/sym_size.h>
1313
+ #include <ATen/ops/sym_storage_offset.h>
1314
+ #include <ATen/ops/sym_stride.h>
1315
+ #include <ATen/ops/t.h>
1316
+ #include <ATen/ops/t_copy.h>
1317
+ #include <ATen/ops/take.h>
1318
+ #include <ATen/ops/take_along_dim.h>
1319
+ #include <ATen/ops/tan.h>
1320
+ #include <ATen/ops/tanh.h>
1321
+ #include <ATen/ops/tanh_backward.h>
1322
+ #include <ATen/ops/tensor_split.h>
1323
+ #include <ATen/ops/tensordot.h>
1324
+ #include <ATen/ops/thnn_conv2d.h>
1325
+ #include <ATen/ops/threshold.h>
1326
+ #include <ATen/ops/threshold_backward.h>
1327
+ #include <ATen/ops/tile.h>
1328
+ #include <ATen/ops/to.h>
1329
+ #include <ATen/ops/to_dense.h>
1330
+ #include <ATen/ops/to_dense_backward.h>
1331
+ #include <ATen/ops/to_mkldnn.h>
1332
+ #include <ATen/ops/to_mkldnn_backward.h>
1333
+ #include <ATen/ops/to_padded_tensor.h>
1334
+ #include <ATen/ops/to_sparse.h>
1335
+ #include <ATen/ops/to_sparse_bsc.h>
1336
+ #include <ATen/ops/to_sparse_bsr.h>
1337
+ #include <ATen/ops/to_sparse_csc.h>
1338
+ #include <ATen/ops/to_sparse_csr.h>
1339
+ #include <ATen/ops/topk.h>
1340
+ #include <ATen/ops/trace.h>
1341
+ #include <ATen/ops/trace_backward.h>
1342
+ #include <ATen/ops/transpose.h>
1343
+ #include <ATen/ops/transpose_copy.h>
1344
+ #include <ATen/ops/trapezoid.h>
1345
+ #include <ATen/ops/trapz.h>
1346
+ #include <ATen/ops/triangular_solve.h>
1347
+ #include <ATen/ops/tril.h>
1348
+ #include <ATen/ops/tril_indices.h>
1349
+ #include <ATen/ops/triplet_margin_loss.h>
1350
+ #include <ATen/ops/triu.h>
1351
+ #include <ATen/ops/triu_indices.h>
1352
+ #include <ATen/ops/true_divide.h>
1353
+ #include <ATen/ops/trunc.h>
1354
+ #include <ATen/ops/type_as.h>
1355
+ #include <ATen/ops/unbind.h>
1356
+ #include <ATen/ops/unbind_copy.h>
1357
+ #include <ATen/ops/unflatten.h>
1358
+ #include <ATen/ops/unflatten_dense_tensors.h>
1359
+ #include <ATen/ops/unfold.h>
1360
+ #include <ATen/ops/unfold_backward.h>
1361
+ #include <ATen/ops/unfold_copy.h>
1362
+ #include <ATen/ops/uniform.h>
1363
+ #include <ATen/ops/unique_consecutive.h>
1364
+ #include <ATen/ops/unique_dim.h>
1365
+ #include <ATen/ops/unique_dim_consecutive.h>
1366
+ #include <ATen/ops/unsafe_chunk.h>
1367
+ #include <ATen/ops/unsafe_split.h>
1368
+ #include <ATen/ops/unsafe_split_with_sizes.h>
1369
+ #include <ATen/ops/unsqueeze.h>
1370
+ #include <ATen/ops/unsqueeze_copy.h>
1371
+ #include <ATen/ops/upsample_bicubic2d.h>
1372
+ #include <ATen/ops/upsample_bicubic2d_backward.h>
1373
+ #include <ATen/ops/upsample_bilinear2d.h>
1374
+ #include <ATen/ops/upsample_bilinear2d_backward.h>
1375
+ #include <ATen/ops/upsample_linear1d.h>
1376
+ #include <ATen/ops/upsample_linear1d_backward.h>
1377
+ #include <ATen/ops/upsample_nearest1d.h>
1378
+ #include <ATen/ops/upsample_nearest1d_backward.h>
1379
+ #include <ATen/ops/upsample_nearest2d.h>
1380
+ #include <ATen/ops/upsample_nearest2d_backward.h>
1381
+ #include <ATen/ops/upsample_nearest3d.h>
1382
+ #include <ATen/ops/upsample_nearest3d_backward.h>
1383
+ #include <ATen/ops/upsample_trilinear3d.h>
1384
+ #include <ATen/ops/upsample_trilinear3d_backward.h>
1385
+ #include <ATen/ops/value_selecting_reduction_backward.h>
1386
+ #include <ATen/ops/values.h>
1387
+ #include <ATen/ops/values_copy.h>
1388
+ #include <ATen/ops/vander.h>
1389
+ #include <ATen/ops/var.h>
1390
+ #include <ATen/ops/var_mean.h>
1391
+ #include <ATen/ops/vdot.h>
1392
+ #include <ATen/ops/view.h>
1393
+ #include <ATen/ops/view_as.h>
1394
+ #include <ATen/ops/view_as_complex.h>
1395
+ #include <ATen/ops/view_as_complex_copy.h>
1396
+ #include <ATen/ops/view_as_real.h>
1397
+ #include <ATen/ops/view_as_real_copy.h>
1398
+ #include <ATen/ops/view_copy.h>
1399
+ #include <ATen/ops/vsplit.h>
1400
+ #include <ATen/ops/vstack.h>
1401
+ #include <ATen/ops/where.h>
1402
+ #include <ATen/ops/xlogy.h>
1403
+ #include <ATen/ops/xor.h>
1404
+ #include <ATen/ops/zero.h>
1405
+ #include <ATen/ops/zeros.h>
1406
+ #include <ATen/ops/zeros_like.h>
1407
+
1408
+ namespace at {
1409
+
1410
+
1411
+
1412
+ // Special C++ only overloads for std()-like functions (See gh-40287)
1413
+ // These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
1414
+ // So, for example std(0) would select the std(unbiased=False) overload
1415
+ inline Tensor var(const Tensor& self, int dim) {
1416
+ return at::var(self, IntArrayRef{dim});
1417
+ }
1418
+ inline std::tuple<Tensor, Tensor> var_mean(const Tensor& self, int dim) {
1419
+ return at::var_mean(self, IntArrayRef{dim});
1420
+ }
1421
+ inline Tensor std(const Tensor& self, int dim) {
1422
+ return at::std(self, IntArrayRef{dim});
1423
+ }
1424
+ inline std::tuple<Tensor, Tensor> std_mean(const Tensor& self, int dim) {
1425
+ return at::std_mean(self, IntArrayRef{dim});
1426
+ }
1427
+
1428
+ inline int64_t numel(const Tensor& tensor) {
1429
+ return tensor.numel();
1430
+ }
1431
+
1432
+ inline int64_t size(const Tensor& tensor, int64_t dim) {
1433
+ return tensor.size(dim);
1434
+ }
1435
+
1436
+ inline int64_t stride(const Tensor& tensor, int64_t dim) {
1437
+ return tensor.stride(dim);
1438
+ }
1439
+
1440
+ inline bool is_complex(const Tensor& tensor) {
1441
+ return tensor.is_complex();
1442
+ }
1443
+
1444
+ inline bool is_floating_point(const Tensor& tensor) {
1445
+ return tensor.is_floating_point();
1446
+ }
1447
+
1448
+ inline bool is_signed(const Tensor& tensor) {
1449
+ return tensor.is_signed();
1450
+ }
1451
+
1452
+ inline bool is_inference(const Tensor& tensor) {
1453
+ return tensor.is_inference();
1454
+ }
1455
+
1456
+ inline bool _is_zerotensor(const Tensor& tensor) {
1457
+ return tensor._is_zerotensor();
1458
+ }
1459
+
1460
+ inline bool is_conj(const Tensor& tensor) {
1461
+ return tensor.is_conj();
1462
+ }
1463
+
1464
+ inline Tensor conj(const Tensor& tensor) {
1465
+ return tensor.conj();
1466
+ }
1467
+
1468
+ inline bool is_neg(const Tensor& tensor) {
1469
+ return tensor.is_neg();
1470
+ }
1471
+
1472
+ }
1473
+
1474
+ #else
1475
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
1476
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/InitialTensorOptions.h ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/TensorOptions.h>
5
+
6
+ namespace at {
7
+
8
+ // Represents the initial TensorOptions, before the "defaults" are ever changed.
9
+ // This is designed to be used in library code, where the explicit devices,
10
+ // dtypes, etc. are known. NOTE: this is not a stable API.
11
+ inline TensorOptions initialTensorOptions() {
12
+ return TensorOptions(kCPU).dtype(kFloat).layout(kStrided).requires_grad(
13
+ false);
14
+ }
15
+
16
+ } // namespace at
17
+
18
+ #else
19
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
20
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <bitset>
5
+
6
+ #include <ATen/ArrayRef.h>
7
+ #include <ATen/SmallVector.h>
8
+ #include <ATen/Tensor.h>
9
+
10
+ namespace at {
11
+
12
+ // We assume this in a few other places in the codebase,
13
+ // but there isn't a centralized definition.
14
+ constexpr int64_t kVmapMaxTensorDims = 64;
15
+
16
+ // The valid vmap levels range from [0, 64). This effectively means that we
17
+ // support a maximum of 64 nested vmaps.
18
+ constexpr int64_t kVmapNumLevels = 64;
19
+
20
+ // Store this number of elements of BatchDims on the stack. Most people will
21
+ // probably use <= 5 nested vmaps, but adjust this number as necessary.
22
+ constexpr int64_t kBatchDimsStackSize = 5;
23
+
24
+ // a BatchDim represents a "private" dimension on a Tensor created inside of
25
+ // vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
26
+ // is being vmap'ed over and the `level` being an identifier for which vmap
27
+ // said dimension was created inside. The `dim` corresponds to a "physical
28
+ // dim" - it is a dimension index on the underlying physical tensor that is
29
+ // being vmapped over.
30
+ struct BatchDim {
31
+ BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
32
+ int64_t dim() const {
33
+ return dim_;
34
+ }
35
+ int64_t level() const {
36
+ return level_;
37
+ }
38
+
39
+ private:
40
+ int64_t dim_;
41
+ int64_t level_;
42
+ };
43
+
44
+ using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
45
+ using BatchDimsRef = ArrayRef<BatchDim>;
46
+
47
+ // A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
48
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
49
+ // BatchedTensorImpl.
50
+ //
51
+ // The batch dimensions are treated as being "private"; they are not
52
+ // user-visible. For example, in the following Tensor,
53
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
54
+ // dimensions 0 and 1 are batch dimensions.
55
+ //
56
+ // bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
57
+ // dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7)
58
+ // tensor.
59
+ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
60
+ explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
61
+
62
+ // Returns a reference to BatchDims that represent which dimensions of this
63
+ // tensor are private.
64
+ BatchDimsRef bdims() const {
65
+ return bdims_;
66
+ }
67
+
68
+ // BatchedTensorImpl wraps a Tensor
69
+ const Tensor& value() const {
70
+ return value_;
71
+ }
72
+
73
+ // Given a public dimension index, return the dimension index in the
74
+ // underlying value() tensor. For example, if we have
75
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2,
76
+ // dim=2)])
77
+ // bt.actualDim(0) -> 1
78
+ // bt.actualDim(1) -> 3
79
+ // bt.actualDim(2) -> Error
80
+ int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
81
+
82
+ // We have to override this because we opted into CustomStrides
83
+ IntArrayRef strides_custom() const override;
84
+ // Override a bunch of methods inherited from TensorImpl to return error
85
+ // messages.
86
+ c10::SymBool sym_is_contiguous_custom(
87
+ at::MemoryFormat memory_format) const override;
88
+ void set_size(int64_t dim, int64_t new_size) override;
89
+ void set_stride(int64_t dim, int64_t new_stride) override;
90
+ void set_storage_offset(int64_t storage_offset) override;
91
+ #ifdef DEBUG
92
+ bool has_storage() const override;
93
+ #endif
94
+
95
+ private:
96
+ // see NOTE: [BatchedTensorImpl levels invariant]
97
+ void checkInvariants() const;
98
+ const char* tensorimpl_type_name() const override;
99
+
100
+ Tensor value_;
101
+
102
+ // Note: [BatchedTensorImpl levels invariant]
103
+ // There is an invariant that the BatchDims must be stored in increasing
104
+ // `level` order. That is, for i < j, bdims_[i].level must be less than
105
+ // bdims_[j].level.
106
+ BatchDims bdims_;
107
+ };
108
+
109
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
110
+ // BatchedTensorImpl.
111
+ inline bool isBatchedTensor(const Tensor& tensor) {
112
+ return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
113
+ }
114
+
115
+ // It is unsafe to call this on a Tensor that is not backed by a
116
+ // BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
117
+ inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
118
+ return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
119
+ }
120
+
121
+ inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
122
+ if (!isBatchedTensor(tensor)) {
123
+ return nullptr;
124
+ }
125
+ return unsafeGetBatchedImpl(tensor);
126
+ }
127
+
128
+ // Returns a bitset. If bit i is set, then that means dim i is a batchdim.
129
+ inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
130
+ BatchDimsRef bdims) {
131
+ std::bitset<kVmapMaxTensorDims> is_bdim;
132
+ for (const auto& bdim : bdims) {
133
+ is_bdim.set(bdim.dim());
134
+ }
135
+ return is_bdim;
136
+ }
137
+
138
+ // Creates a bitset for all of the levels present in `bdims`
139
+ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
140
+ std::bitset<kVmapNumLevels> result;
141
+ for (const auto& bdim : bdims) {
142
+ result.set(bdim.level());
143
+ }
144
+ return result;
145
+ }
146
+
147
+ inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
148
+ out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ')';
149
+ return out;
150
+ }
151
+
152
+ // Use this to construct a BatchedTensor from a regular Tensor
153
+ TORCH_API Tensor makeBatched(Tensor tensor, BatchDims bdims);
154
+
155
+ // Adds a batch dim to `tensor`, returning a BatchedTensor
156
+ TORCH_API Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim);
157
+
158
+ // Checks if an inplace operation on self and other is "vmap compatible".
159
+ // See NOTE: [vmap-incompatible in-place operations] for the definition of this.
160
+ TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
161
+
162
+ } // namespace at
163
+
164
+ #else
165
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
166
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapMode.h ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/impl/LocalDispatchKeySet.h>
5
+
6
+ namespace at::impl {
7
+
8
+ // VmapMode contains a thread local count of how many nested vmaps
9
+ // we are currently inside. That number is known as the `vmap level`.
10
+ // VmapMode is used in the implementation of the Python `torch.vmap` API.
11
+ //
12
+ // NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
13
+
14
+ struct TORCH_API VmapMode {
15
+ // Returns the vmap level, aka the count of how many nested vmaps we're in.
16
+ static int64_t current_vmap_level();
17
+
18
+ // Increment the count of nested vmaps. If this causes the vmap level to be
19
+ // greater than 0, then it enables DispatchKey::VmapMode on all tensors.
20
+ static int64_t increment_nesting();
21
+
22
+ // Decrements the count of nested vmaps. If this causes the vmap level to be
23
+ // equal to 0, then it disables DispatchKey::VmapMode on all tensors.
24
+ static int64_t decrement_nesting();
25
+ };
26
+
27
+ } // namespace at::impl
28
+
29
+ #else
30
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
31
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapTransforms.h ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/LegacyBatchedTensorImpl.h>
5
+ #include <ATen/core/IListRef.h>
6
+
7
+ namespace at {
8
+
9
+ // This file contains abstractions used for transforming *logical* vmap
10
+ // arguments into *physical* arguments. (Keep reading for definitions of these
11
+ // terms).
12
+
13
+ // NOTE: [Logical vs physical args]
14
+ // Consider the following vmap.
15
+ // vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
16
+ // This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
17
+ // with batch dims 0 and 2:
18
+ // BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
19
+ //
20
+ // We say the *logical* view of the tensor has size [3] -- tensors inside
21
+ // `func` appear to have size [3].
22
+ // However, the *physical* underlying tensor (the one passed to vmap) has size
23
+ // [2, 3, 4].
24
+ //
25
+ // This notion of logical vs physical also extends to non-tensor arguments.
26
+ // Consider the previous tensor; let's assume the user called
27
+ // `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
28
+ // dimension they are reducing over is dim 0 but the physical dim is dim 1
29
+ // (the first non-batch dimension)
30
+
31
+ // Forward declared; see NOTE: [What is a VmapPhysicalView?]
32
+ struct VmapPhysicalView;
33
+
34
+ // Most PyTorch operators take 4 or fewer inputs.
35
+ constexpr int64_t kVmapTransformStaticInputSize = 4;
36
+ using VmapPhysicalViewVec =
37
+ SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
38
+
39
+ // Pytorch generally advertises good performance for <= 5 dims.
40
+ // (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
41
+ // dimensions to get 8. Adjust this number as necessary
42
+ constexpr int64_t kVmapStaticDimVecSize = 8;
43
+ using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
44
+ using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;
45
+
46
+ // NOTE: [What is an VmapTransform?]
47
+ // An *VmapTransform* converts logical views of tensors to physical views.
48
+ //
49
+ // Batching rules use VmapTransforms to convert logical arguments to
50
+ // physical arguments, then call one or more at:: operator that handles the
51
+ // physical arguments, and then converts the physical result back to a logical
52
+ // argument.
53
+
54
+ // VmapTransform for operators that take tensors with multiple batch dims.
55
+ // Given one or more logical views on Tensors, `logicalToPhysical`
56
+ // permutes all of the batch dims to the front of the tensor, aligns
57
+ // and expands the batch dims to match each other (according to their `level`),
58
+ // and returns a VmapPhysicalView on the tensor(s).
59
+ struct TORCH_API MultiBatchVmapTransform {
60
+ static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
61
+ static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
62
+ };
63
+
64
+ // VmapTransform for operators that broadcast all inputs.
65
+ // Given some logical views on Tensors, `logicalToPhysical`:
66
+ // - permutes all of the batch dims to the front of the tensors
67
+ // - aligns all the batch dims to the collective levels of all of the tensors.
68
+ // If a tensor does not have a batch dim for a vmap level, then it receives
69
+ // a size-one dimension for said level.
70
+ // - aligns the non-batch dims to have the same dimensionality, adding extra
71
+ // size-1 dimensions in between the batch dimensions and the non-batch
72
+ // dimensions so that the batch dimensions are lined up from the right.
73
+ //
74
+ // For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
75
+ // dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
76
+ // tensors of size (B, 1, 2) and (B, 3, 2).
77
+ //
78
+ // Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
79
+ // VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
80
+ // actually *need* to return a tensor of size (1, 2) for the second tensor
81
+ // because the broadcasting operation takes care of that for us, but we do
82
+ // it anyways to keep things simple.
83
+ struct TORCH_API BroadcastingVmapTransform {
84
+ static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
85
+ };
86
+
87
+ // Forward declared, if you're reading this file head to toe, don't worry about
88
+ // it yet.
89
+ struct VmapPhysicalToLogicalMap;
90
+
91
+ // NOTE: [What is a VmapPhysicalView?]
92
+ // VmapPhysicalView represents a physical view on a Tensor.
93
+ //
94
+ // One can use it to further convert logical dimension indices, logical shapes,
95
+ // and more to their physical variants, or convert a new (physical) tensor into
96
+ // a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
97
+ //
98
+ // VmapPhysicalView stores a physical tensor with all of its batch dimensions at
99
+ // the front and some levels that correspond to said batch dimensions.
100
+ //
101
+ // The levels bitset specifies which vmap levels correspond to the batch
102
+ // dimensions at the front of the tensor. In particular, the number of set bits
103
+ // corresponds to the number of batch dimensions on `tensor` and the rightmost
104
+ // bit of `levels` specifies the maximum number of nested vmaps we are in at
105
+ // this point in time.
106
+ // For example, given:
107
+ // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
108
+ //
109
+ // Rightmost bit of `levels` is 3 indicating the number of nested vmaps less
110
+ // than or equal to 3.
111
+ // bitset: 010100
112
+ // ^
113
+ // |
114
+ // levels: 012345
115
+ struct TORCH_API VmapPhysicalView {
116
+ VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
117
+ : levels_(levels), tensor_(std::move(tensor)) {
118
+ TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
119
+ }
120
+
121
+ Tensor& tensor() {
122
+ return tensor_;
123
+ }
124
+ const Tensor& tensor() const {
125
+ return tensor_;
126
+ }
127
+
128
+ // Maps logical dim indices to physical dim indices. Also does dim wrapping.
129
+ //
130
+ // For example, given:
131
+ // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
132
+ //
133
+ // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
134
+ // This is because the size of levels tell us that the first two dimensions
135
+ // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
136
+ // a physical dim of `n + 2`.
137
+ VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
138
+ int64_t getPhysicalDim(int64_t logical_dim) const;
139
+
140
+ // Returns a VmapPhysicalToLogicalMap object. This can be used for
141
+ // mapping a physical tensor to a new logical tensor (BatchedTensor)
142
+ VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
143
+
144
+ // Maps a logical shape to a physical shape by prepending the batch
145
+ // sizes to the logical shape.
146
+ VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
147
+
148
+ int64_t numBatchDims() const;
149
+
150
+ private:
151
+ int64_t numLogicalDims() const;
152
+
153
+ std::bitset<kVmapNumLevels> levels_;
154
+ Tensor tensor_;
155
+ };
156
+
157
+ // Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
158
+ // to a logical one (BatchedTensor). It holds some levels that are used to do
159
+ // the mapping and assumes that the batch dimensions in the physical tensor all
160
+ // occur at the front of the tensor.
161
+ struct TORCH_API VmapPhysicalToLogicalMap {
162
+ VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
163
+ : levels_(levels) {}
164
+
165
+ // Maps a physical tensor to a new logical tensor (BatchedTensor).
166
+ // Assumes that all of the "batch dimensions" are at the front
167
+ // of the physical tensor. For example, given:
168
+ // - x = rank-4 Tensor with size 2, 3, 5, 7
169
+ // - levels = (2, 4)
170
+ // Returns:
171
+ // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
172
+ Tensor apply(const Tensor& physical_tensor) const;
173
+
174
+ // Given a vector of physical tensors,
175
+ // 1. maps each tensor to a new logical tensor. Assumes that all of the
176
+ // "batch dimensions" are at the front of the physical tensors.
177
+ // 2. stores the new logical tensors back into the passed-in vector. This is
178
+ // to avoid additional dynamic allocations.
179
+ void applyInplace(std::vector<Tensor>& physical_tensors) const;
180
+
181
+ std::bitset<kVmapNumLevels> levels_;
182
+ };
183
+
184
+ } // namespace at
185
+
186
+ #else
187
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
188
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/MethodOperators.h ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // @generated by torchgen/gen.py from MethodOperators.h
5
+
6
+ #ifdef TORCH_ASSERT_NO_OPERATORS
7
+ #error This change adds a dependency on native_functions.yaml, \
8
+ meaning the file will need to be re-compiled every time an operator \
9
+ is changed or added. Consider if your change would be better placed in \
10
+ another file, or if a more specific header might achieve the same goal. \
11
+ See NOTE: [Tensor vs. TensorBase]
12
+ #endif
13
+
14
+ // Forward declarations of any types needed in the operator signatures.
15
+ // We can't directly include these classes because it will cause circular include dependencies.
16
+ // This file is included by TensorBody.h, which defines the Tensor class.
17
+ #include <ATen/core/ATen_fwd.h>
18
+
19
+ #include <ATen/ops/_addmm_activation_ops.h>
20
+ #include <ATen/ops/_autocast_to_full_precision_ops.h>
21
+ #include <ATen/ops/_autocast_to_reduced_precision_ops.h>
22
+ #include <ATen/ops/_backward_ops.h>
23
+ #include <ATen/ops/_coalesced_ops.h>
24
+ #include <ATen/ops/_conj_ops.h>
25
+ #include <ATen/ops/_conj_physical_ops.h>
26
+ #include <ATen/ops/_dimI_ops.h>
27
+ #include <ATen/ops/_dimV_ops.h>
28
+ #include <ATen/ops/_fw_primal_ops.h>
29
+ #include <ATen/ops/_indices_ops.h>
30
+ #include <ATen/ops/_is_all_true_ops.h>
31
+ #include <ATen/ops/_is_any_true_ops.h>
32
+ #include <ATen/ops/_is_zerotensor_ops.h>
33
+ #include <ATen/ops/_lazy_clone_ops.h>
34
+ #include <ATen/ops/_neg_view_ops.h>
35
+ #include <ATen/ops/_nested_tensor_size_ops.h>
36
+ #include <ATen/ops/_nested_tensor_storage_offsets_ops.h>
37
+ #include <ATen/ops/_nested_tensor_strides_ops.h>
38
+ #include <ATen/ops/_nnz_ops.h>
39
+ #include <ATen/ops/_reshape_alias_ops.h>
40
+ #include <ATen/ops/_sparse_mask_projection_ops.h>
41
+ #include <ATen/ops/_to_dense_ops.h>
42
+ #include <ATen/ops/_to_sparse_bsc_ops.h>
43
+ #include <ATen/ops/_to_sparse_bsr_ops.h>
44
+ #include <ATen/ops/_to_sparse_csc_ops.h>
45
+ #include <ATen/ops/_to_sparse_csr_ops.h>
46
+ #include <ATen/ops/_to_sparse_ops.h>
47
+ #include <ATen/ops/_values_ops.h>
48
+ #include <ATen/ops/_version_ops.h>
49
+ #include <ATen/ops/abs_ops.h>
50
+ #include <ATen/ops/absolute_ops.h>
51
+ #include <ATen/ops/acos_ops.h>
52
+ #include <ATen/ops/acosh_ops.h>
53
+ #include <ATen/ops/add_ops.h>
54
+ #include <ATen/ops/addbmm_ops.h>
55
+ #include <ATen/ops/addcdiv_ops.h>
56
+ #include <ATen/ops/addcmul_ops.h>
57
+ #include <ATen/ops/addmm_ops.h>
58
+ #include <ATen/ops/addmv_ops.h>
59
+ #include <ATen/ops/addr_ops.h>
60
+ #include <ATen/ops/adjoint_ops.h>
61
+ #include <ATen/ops/alias_ops.h>
62
+ #include <ATen/ops/align_as_ops.h>
63
+ #include <ATen/ops/align_to_ops.h>
64
+ #include <ATen/ops/all_ops.h>
65
+ #include <ATen/ops/allclose_ops.h>
66
+ #include <ATen/ops/amax_ops.h>
67
+ #include <ATen/ops/amin_ops.h>
68
+ #include <ATen/ops/aminmax_ops.h>
69
+ #include <ATen/ops/and_ops.h>
70
+ #include <ATen/ops/angle_ops.h>
71
+ #include <ATen/ops/any_ops.h>
72
+ #include <ATen/ops/arccos_ops.h>
73
+ #include <ATen/ops/arccosh_ops.h>
74
+ #include <ATen/ops/arcsin_ops.h>
75
+ #include <ATen/ops/arcsinh_ops.h>
76
+ #include <ATen/ops/arctan2_ops.h>
77
+ #include <ATen/ops/arctan_ops.h>
78
+ #include <ATen/ops/arctanh_ops.h>
79
+ #include <ATen/ops/argmax_ops.h>
80
+ #include <ATen/ops/argmin_ops.h>
81
+ #include <ATen/ops/argsort_ops.h>
82
+ #include <ATen/ops/argwhere_ops.h>
83
+ #include <ATen/ops/as_strided_ops.h>
84
+ #include <ATen/ops/as_strided_scatter_ops.h>
85
+ #include <ATen/ops/asin_ops.h>
86
+ #include <ATen/ops/asinh_ops.h>
87
+ #include <ATen/ops/atan2_ops.h>
88
+ #include <ATen/ops/atan_ops.h>
89
+ #include <ATen/ops/atanh_ops.h>
90
+ #include <ATen/ops/baddbmm_ops.h>
91
+ #include <ATen/ops/bernoulli_ops.h>
92
+ #include <ATen/ops/bincount_ops.h>
93
+ #include <ATen/ops/bitwise_and_ops.h>
94
+ #include <ATen/ops/bitwise_left_shift_ops.h>
95
+ #include <ATen/ops/bitwise_not_ops.h>
96
+ #include <ATen/ops/bitwise_or_ops.h>
97
+ #include <ATen/ops/bitwise_right_shift_ops.h>
98
+ #include <ATen/ops/bitwise_xor_ops.h>
99
+ #include <ATen/ops/bmm_ops.h>
100
+ #include <ATen/ops/broadcast_to_ops.h>
101
+ #include <ATen/ops/cauchy_ops.h>
102
+ #include <ATen/ops/ccol_indices_ops.h>
103
+ #include <ATen/ops/ceil_ops.h>
104
+ #include <ATen/ops/chalf_ops.h>
105
+ #include <ATen/ops/cholesky_inverse_ops.h>
106
+ #include <ATen/ops/cholesky_ops.h>
107
+ #include <ATen/ops/cholesky_solve_ops.h>
108
+ #include <ATen/ops/chunk_ops.h>
109
+ #include <ATen/ops/clamp_max_ops.h>
110
+ #include <ATen/ops/clamp_min_ops.h>
111
+ #include <ATen/ops/clamp_ops.h>
112
+ #include <ATen/ops/clip_ops.h>
113
+ #include <ATen/ops/clone_ops.h>
114
+ #include <ATen/ops/coalesce_ops.h>
115
+ #include <ATen/ops/col_indices_ops.h>
116
+ #include <ATen/ops/conj_ops.h>
117
+ #include <ATen/ops/conj_physical_ops.h>
118
+ #include <ATen/ops/contiguous_ops.h>
119
+ #include <ATen/ops/copy_ops.h>
120
+ #include <ATen/ops/copysign_ops.h>
121
+ #include <ATen/ops/corrcoef_ops.h>
122
+ #include <ATen/ops/cos_ops.h>
123
+ #include <ATen/ops/cosh_ops.h>
124
+ #include <ATen/ops/count_nonzero_ops.h>
125
+ #include <ATen/ops/cov_ops.h>
126
+ #include <ATen/ops/cross_ops.h>
127
+ #include <ATen/ops/crow_indices_ops.h>
128
+ #include <ATen/ops/cummax_ops.h>
129
+ #include <ATen/ops/cummin_ops.h>
130
+ #include <ATen/ops/cumprod_ops.h>
131
+ #include <ATen/ops/cumsum_ops.h>
132
+ #include <ATen/ops/data_ops.h>
133
+ #include <ATen/ops/deg2rad_ops.h>
134
+ #include <ATen/ops/dense_dim_ops.h>
135
+ #include <ATen/ops/dequantize_ops.h>
136
+ #include <ATen/ops/det_ops.h>
137
+ #include <ATen/ops/detach_ops.h>
138
+ #include <ATen/ops/diag_embed_ops.h>
139
+ #include <ATen/ops/diag_ops.h>
140
+ #include <ATen/ops/diagflat_ops.h>
141
+ #include <ATen/ops/diagonal_ops.h>
142
+ #include <ATen/ops/diagonal_scatter_ops.h>
143
+ #include <ATen/ops/diff_ops.h>
144
+ #include <ATen/ops/digamma_ops.h>
145
+ #include <ATen/ops/dist_ops.h>
146
+ #include <ATen/ops/div_ops.h>
147
+ #include <ATen/ops/divide_ops.h>
148
+ #include <ATen/ops/dot_ops.h>
149
+ #include <ATen/ops/dsplit_ops.h>
150
+ #include <ATen/ops/eq_ops.h>
151
+ #include <ATen/ops/equal_ops.h>
152
+ #include <ATen/ops/erf_ops.h>
153
+ #include <ATen/ops/erfc_ops.h>
154
+ #include <ATen/ops/erfinv_ops.h>
155
+ #include <ATen/ops/exp2_ops.h>
156
+ #include <ATen/ops/exp_ops.h>
157
+ #include <ATen/ops/expand_as_ops.h>
158
+ #include <ATen/ops/expand_ops.h>
159
+ #include <ATen/ops/expm1_ops.h>
160
+ #include <ATen/ops/exponential_ops.h>
161
+ #include <ATen/ops/fill_diagonal_ops.h>
162
+ #include <ATen/ops/fill_ops.h>
163
+ #include <ATen/ops/fix_ops.h>
164
+ #include <ATen/ops/flatten_ops.h>
165
+ #include <ATen/ops/flip_ops.h>
166
+ #include <ATen/ops/fliplr_ops.h>
167
+ #include <ATen/ops/flipud_ops.h>
168
+ #include <ATen/ops/float_power_ops.h>
169
+ #include <ATen/ops/floor_divide_ops.h>
170
+ #include <ATen/ops/floor_ops.h>
171
+ #include <ATen/ops/fmax_ops.h>
172
+ #include <ATen/ops/fmin_ops.h>
173
+ #include <ATen/ops/fmod_ops.h>
174
+ #include <ATen/ops/frac_ops.h>
175
+ #include <ATen/ops/frexp_ops.h>
176
+ #include <ATen/ops/gather_ops.h>
177
+ #include <ATen/ops/gcd_ops.h>
178
+ #include <ATen/ops/ge_ops.h>
179
+ #include <ATen/ops/geometric_ops.h>
180
+ #include <ATen/ops/geqrf_ops.h>
181
+ #include <ATen/ops/ger_ops.h>
182
+ #include <ATen/ops/greater_equal_ops.h>
183
+ #include <ATen/ops/greater_ops.h>
184
+ #include <ATen/ops/gt_ops.h>
185
+ #include <ATen/ops/hardshrink_backward_ops.h>
186
+ #include <ATen/ops/hardshrink_ops.h>
187
+ #include <ATen/ops/hash_tensor_ops.h>
188
+ #include <ATen/ops/heaviside_ops.h>
189
+ #include <ATen/ops/histc_ops.h>
190
+ #include <ATen/ops/histogram_ops.h>
191
+ #include <ATen/ops/hsplit_ops.h>
192
+ #include <ATen/ops/hypot_ops.h>
193
+ #include <ATen/ops/i0_ops.h>
194
+ #include <ATen/ops/igamma_ops.h>
195
+ #include <ATen/ops/igammac_ops.h>
196
+ #include <ATen/ops/index_add_ops.h>
197
+ #include <ATen/ops/index_copy_ops.h>
198
+ #include <ATen/ops/index_fill_ops.h>
199
+ #include <ATen/ops/index_ops.h>
200
+ #include <ATen/ops/index_put_ops.h>
201
+ #include <ATen/ops/index_reduce_ops.h>
202
+ #include <ATen/ops/index_select_ops.h>
203
+ #include <ATen/ops/indices_ops.h>
204
+ #include <ATen/ops/inner_ops.h>
205
+ #include <ATen/ops/int_repr_ops.h>
206
+ #include <ATen/ops/inverse_ops.h>
207
+ #include <ATen/ops/is_coalesced_ops.h>
208
+ #include <ATen/ops/is_complex_ops.h>
209
+ #include <ATen/ops/is_conj_ops.h>
210
+ #include <ATen/ops/is_distributed_ops.h>
211
+ #include <ATen/ops/is_floating_point_ops.h>
212
+ #include <ATen/ops/is_inference_ops.h>
213
+ #include <ATen/ops/is_leaf_ops.h>
214
+ #include <ATen/ops/is_neg_ops.h>
215
+ #include <ATen/ops/is_nonzero_ops.h>
216
+ #include <ATen/ops/is_pinned_ops.h>
217
+ #include <ATen/ops/is_same_size_ops.h>
218
+ #include <ATen/ops/is_set_to_ops.h>
219
+ #include <ATen/ops/is_signed_ops.h>
220
+ #include <ATen/ops/isclose_ops.h>
221
+ #include <ATen/ops/isfinite_ops.h>
222
+ #include <ATen/ops/isinf_ops.h>
223
+ #include <ATen/ops/isnan_ops.h>
224
+ #include <ATen/ops/isneginf_ops.h>
225
+ #include <ATen/ops/isposinf_ops.h>
226
+ #include <ATen/ops/isreal_ops.h>
227
+ #include <ATen/ops/istft_ops.h>
228
+ #include <ATen/ops/item_ops.h>
229
+ #include <ATen/ops/kron_ops.h>
230
+ #include <ATen/ops/kthvalue_ops.h>
231
+ #include <ATen/ops/lcm_ops.h>
232
+ #include <ATen/ops/ldexp_ops.h>
233
+ #include <ATen/ops/le_ops.h>
234
+ #include <ATen/ops/lerp_ops.h>
235
+ #include <ATen/ops/less_equal_ops.h>
236
+ #include <ATen/ops/less_ops.h>
237
+ #include <ATen/ops/lgamma_ops.h>
238
+ #include <ATen/ops/log10_ops.h>
239
+ #include <ATen/ops/log1p_ops.h>
240
+ #include <ATen/ops/log2_ops.h>
241
+ #include <ATen/ops/log_normal_ops.h>
242
+ #include <ATen/ops/log_ops.h>
243
+ #include <ATen/ops/log_softmax_ops.h>
244
+ #include <ATen/ops/logaddexp2_ops.h>
245
+ #include <ATen/ops/logaddexp_ops.h>
246
+ #include <ATen/ops/logcumsumexp_ops.h>
247
+ #include <ATen/ops/logdet_ops.h>
248
+ #include <ATen/ops/logical_and_ops.h>
249
+ #include <ATen/ops/logical_not_ops.h>
250
+ #include <ATen/ops/logical_or_ops.h>
251
+ #include <ATen/ops/logical_xor_ops.h>
252
+ #include <ATen/ops/logit_ops.h>
253
+ #include <ATen/ops/logsumexp_ops.h>
254
+ #include <ATen/ops/lshift_ops.h>
255
+ #include <ATen/ops/lt_ops.h>
256
+ #include <ATen/ops/lu_solve_ops.h>
257
+ #include <ATen/ops/mH_ops.h>
258
+ #include <ATen/ops/mT_ops.h>
259
+ #include <ATen/ops/masked_fill_ops.h>
260
+ #include <ATen/ops/masked_scatter_ops.h>
261
+ #include <ATen/ops/masked_select_ops.h>
262
+ #include <ATen/ops/matmul_ops.h>
263
+ #include <ATen/ops/matrix_H_ops.h>
264
+ #include <ATen/ops/matrix_exp_ops.h>
265
+ #include <ATen/ops/matrix_power_ops.h>
266
+ #include <ATen/ops/max_ops.h>
267
+ #include <ATen/ops/maximum_ops.h>
268
+ #include <ATen/ops/mean_ops.h>
269
+ #include <ATen/ops/median_ops.h>
270
+ #include <ATen/ops/min_ops.h>
271
+ #include <ATen/ops/minimum_ops.h>
272
+ #include <ATen/ops/mm_ops.h>
273
+ #include <ATen/ops/mode_ops.h>
274
+ #include <ATen/ops/moveaxis_ops.h>
275
+ #include <ATen/ops/movedim_ops.h>
276
+ #include <ATen/ops/msort_ops.h>
277
+ #include <ATen/ops/mul_ops.h>
278
+ #include <ATen/ops/multinomial_ops.h>
279
+ #include <ATen/ops/multiply_ops.h>
280
+ #include <ATen/ops/mv_ops.h>
281
+ #include <ATen/ops/mvlgamma_ops.h>
282
+ #include <ATen/ops/nan_to_num_ops.h>
283
+ #include <ATen/ops/nanmean_ops.h>
284
+ #include <ATen/ops/nanmedian_ops.h>
285
+ #include <ATen/ops/nanquantile_ops.h>
286
+ #include <ATen/ops/nansum_ops.h>
287
+ #include <ATen/ops/narrow_copy_ops.h>
288
+ #include <ATen/ops/narrow_ops.h>
289
+ #include <ATen/ops/ne_ops.h>
290
+ #include <ATen/ops/neg_ops.h>
291
+ #include <ATen/ops/negative_ops.h>
292
+ #include <ATen/ops/new_empty_ops.h>
293
+ #include <ATen/ops/new_empty_strided_ops.h>
294
+ #include <ATen/ops/new_full_ops.h>
295
+ #include <ATen/ops/new_ones_ops.h>
296
+ #include <ATen/ops/new_zeros_ops.h>
297
+ #include <ATen/ops/nextafter_ops.h>
298
+ #include <ATen/ops/nonzero_numpy_ops.h>
299
+ #include <ATen/ops/nonzero_ops.h>
300
+ #include <ATen/ops/nonzero_static_ops.h>
301
+ #include <ATen/ops/norm_ops.h>
302
+ #include <ATen/ops/normal_ops.h>
303
+ #include <ATen/ops/not_equal_ops.h>
304
+ #include <ATen/ops/numpy_T_ops.h>
305
+ #include <ATen/ops/or_ops.h>
306
+ #include <ATen/ops/orgqr_ops.h>
307
+ #include <ATen/ops/ormqr_ops.h>
308
+ #include <ATen/ops/outer_ops.h>
309
+ #include <ATen/ops/output_nr_ops.h>
310
+ #include <ATen/ops/permute_ops.h>
311
+ #include <ATen/ops/pin_memory_ops.h>
312
+ #include <ATen/ops/pinverse_ops.h>
313
+ #include <ATen/ops/polygamma_ops.h>
314
+ #include <ATen/ops/positive_ops.h>
315
+ #include <ATen/ops/pow_ops.h>
316
+ #include <ATen/ops/prelu_ops.h>
317
+ #include <ATen/ops/prod_ops.h>
318
+ #include <ATen/ops/put_ops.h>
319
+ #include <ATen/ops/q_per_channel_axis_ops.h>
320
+ #include <ATen/ops/q_per_channel_scales_ops.h>
321
+ #include <ATen/ops/q_per_channel_zero_points_ops.h>
322
+ #include <ATen/ops/q_scale_ops.h>
323
+ #include <ATen/ops/q_zero_point_ops.h>
324
+ #include <ATen/ops/qr_ops.h>
325
+ #include <ATen/ops/qscheme_ops.h>
326
+ #include <ATen/ops/quantile_ops.h>
327
+ #include <ATen/ops/rad2deg_ops.h>
328
+ #include <ATen/ops/random_ops.h>
329
+ #include <ATen/ops/ravel_ops.h>
330
+ #include <ATen/ops/reciprocal_ops.h>
331
+ #include <ATen/ops/record_stream_ops.h>
332
+ #include <ATen/ops/refine_names_ops.h>
333
+ #include <ATen/ops/relu_ops.h>
334
+ #include <ATen/ops/remainder_ops.h>
335
+ #include <ATen/ops/rename_ops.h>
336
+ #include <ATen/ops/renorm_ops.h>
337
+ #include <ATen/ops/repeat_interleave_ops.h>
338
+ #include <ATen/ops/repeat_ops.h>
339
+ #include <ATen/ops/requires_grad_ops.h>
340
+ #include <ATen/ops/reshape_as_ops.h>
341
+ #include <ATen/ops/reshape_ops.h>
342
+ #include <ATen/ops/resize_as_ops.h>
343
+ #include <ATen/ops/resize_as_sparse_ops.h>
344
+ #include <ATen/ops/resize_ops.h>
345
+ #include <ATen/ops/resolve_conj_ops.h>
346
+ #include <ATen/ops/resolve_neg_ops.h>
347
+ #include <ATen/ops/retain_grad_ops.h>
348
+ #include <ATen/ops/retains_grad_ops.h>
349
+ #include <ATen/ops/roll_ops.h>
350
+ #include <ATen/ops/rot90_ops.h>
351
+ #include <ATen/ops/round_ops.h>
352
+ #include <ATen/ops/row_indices_ops.h>
353
+ #include <ATen/ops/rshift_ops.h>
354
+ #include <ATen/ops/rsqrt_ops.h>
355
+ #include <ATen/ops/scatter_add_ops.h>
356
+ #include <ATen/ops/scatter_ops.h>
357
+ #include <ATen/ops/scatter_reduce_ops.h>
358
+ #include <ATen/ops/select_ops.h>
359
+ #include <ATen/ops/select_scatter_ops.h>
360
+ #include <ATen/ops/set_data_ops.h>
361
+ #include <ATen/ops/set_ops.h>
362
+ #include <ATen/ops/sgn_ops.h>
363
+ #include <ATen/ops/sigmoid_ops.h>
364
+ #include <ATen/ops/sign_ops.h>
365
+ #include <ATen/ops/signbit_ops.h>
366
+ #include <ATen/ops/sin_ops.h>
367
+ #include <ATen/ops/sinc_ops.h>
368
+ #include <ATen/ops/sinh_ops.h>
369
+ #include <ATen/ops/size_ops.h>
370
+ #include <ATen/ops/slice_inverse_ops.h>
371
+ #include <ATen/ops/slice_ops.h>
372
+ #include <ATen/ops/slice_scatter_ops.h>
373
+ #include <ATen/ops/slogdet_ops.h>
374
+ #include <ATen/ops/smm_ops.h>
375
+ #include <ATen/ops/softmax_ops.h>
376
+ #include <ATen/ops/sort_ops.h>
377
+ #include <ATen/ops/sparse_dim_ops.h>
378
+ #include <ATen/ops/sparse_mask_ops.h>
379
+ #include <ATen/ops/sparse_resize_and_clear_ops.h>
380
+ #include <ATen/ops/sparse_resize_ops.h>
381
+ #include <ATen/ops/split_ops.h>
382
+ #include <ATen/ops/split_with_sizes_ops.h>
383
+ #include <ATen/ops/sqrt_ops.h>
384
+ #include <ATen/ops/square_ops.h>
385
+ #include <ATen/ops/squeeze_ops.h>
386
+ #include <ATen/ops/sspaddmm_ops.h>
387
+ #include <ATen/ops/std_ops.h>
388
+ #include <ATen/ops/stft_ops.h>
389
+ #include <ATen/ops/stride_ops.h>
390
+ #include <ATen/ops/sub_ops.h>
391
+ #include <ATen/ops/subtract_ops.h>
392
+ #include <ATen/ops/sum_ops.h>
393
+ #include <ATen/ops/sum_to_size_ops.h>
394
+ #include <ATen/ops/svd_ops.h>
395
+ #include <ATen/ops/swapaxes_ops.h>
396
+ #include <ATen/ops/swapdims_ops.h>
397
+ #include <ATen/ops/t_ops.h>
398
+ #include <ATen/ops/take_along_dim_ops.h>
399
+ #include <ATen/ops/take_ops.h>
400
+ #include <ATen/ops/tan_ops.h>
401
+ #include <ATen/ops/tanh_ops.h>
402
+ #include <ATen/ops/tensor_split_ops.h>
403
+ #include <ATen/ops/tile_ops.h>
404
+ #include <ATen/ops/to_dense_ops.h>
405
+ #include <ATen/ops/to_mkldnn_ops.h>
406
+ #include <ATen/ops/to_ops.h>
407
+ #include <ATen/ops/to_padded_tensor_ops.h>
408
+ #include <ATen/ops/to_sparse_bsc_ops.h>
409
+ #include <ATen/ops/to_sparse_bsr_ops.h>
410
+ #include <ATen/ops/to_sparse_csc_ops.h>
411
+ #include <ATen/ops/to_sparse_csr_ops.h>
412
+ #include <ATen/ops/to_sparse_ops.h>
413
+ #include <ATen/ops/topk_ops.h>
414
+ #include <ATen/ops/trace_ops.h>
415
+ #include <ATen/ops/transpose_ops.h>
416
+ #include <ATen/ops/triangular_solve_ops.h>
417
+ #include <ATen/ops/tril_ops.h>
418
+ #include <ATen/ops/triu_ops.h>
419
+ #include <ATen/ops/true_divide_ops.h>
420
+ #include <ATen/ops/trunc_ops.h>
421
+ #include <ATen/ops/type_as_ops.h>
422
+ #include <ATen/ops/unbind_ops.h>
423
+ #include <ATen/ops/unflatten_ops.h>
424
+ #include <ATen/ops/unfold_ops.h>
425
+ #include <ATen/ops/uniform_ops.h>
426
+ #include <ATen/ops/unsafe_chunk_ops.h>
427
+ #include <ATen/ops/unsafe_split_ops.h>
428
+ #include <ATen/ops/unsafe_split_with_sizes_ops.h>
429
+ #include <ATen/ops/unsqueeze_ops.h>
430
+ #include <ATen/ops/values_ops.h>
431
+ #include <ATen/ops/var_ops.h>
432
+ #include <ATen/ops/vdot_ops.h>
433
+ #include <ATen/ops/view_as_ops.h>
434
+ #include <ATen/ops/view_ops.h>
435
+ #include <ATen/ops/vsplit_ops.h>
436
+ #include <ATen/ops/where_ops.h>
437
+ #include <ATen/ops/xlogy_ops.h>
438
+ #include <ATen/ops/xor_ops.h>
439
+ #include <ATen/ops/zero_ops.h>
440
+
441
+ namespace at {
442
+ namespace _ops {
443
+
444
+ } // namespace _ops
445
+ } // namespace at
446
+
447
+ #else
448
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
449
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NamedTensor.h ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #include <ATen/core/NamedTensor.h>
3
+
4
+ #else
5
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
6
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NativeMetaFunctions.h ADDED
@@ -0,0 +1,1352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // @generated by torchgen/gen.py from NativeMetaFunctions.h
5
+
6
+ #include <ATen/core/Tensor.h>
7
+ #include <ATen/core/IListRef.h>
8
+ #include <ATen/TensorMeta.h>
9
+ #include <ATen/TensorIterator.h>
10
+
11
+ #include <ATen/ops/_adaptive_avg_pool2d_meta.h>
12
+ #include <ATen/ops/_adaptive_avg_pool2d_backward_meta.h>
13
+ #include <ATen/ops/_adaptive_avg_pool3d_meta.h>
14
+ #include <ATen/ops/_adaptive_avg_pool3d_backward_meta.h>
15
+ #include <ATen/ops/_add_batch_dim_meta.h>
16
+ #include <ATen/ops/_add_relu_meta.h>
17
+ #include <ATen/ops/_addmm_activation_meta.h>
18
+ #include <ATen/ops/_aminmax_meta.h>
19
+ #include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_meta.h>
20
+ #include <ATen/ops/_amp_update_scale_meta.h>
21
+ #include <ATen/ops/_assert_async_meta.h>
22
+ #include <ATen/ops/_assert_scalar_meta.h>
23
+ #include <ATen/ops/_assert_tensor_metadata_meta.h>
24
+ #include <ATen/ops/_autocast_to_full_precision_meta.h>
25
+ #include <ATen/ops/_autocast_to_reduced_precision_meta.h>
26
+ #include <ATen/ops/_backward_meta.h>
27
+ #include <ATen/ops/_batch_norm_impl_index_meta.h>
28
+ #include <ATen/ops/_batch_norm_impl_index_backward_meta.h>
29
+ #include <ATen/ops/_batch_norm_no_update_meta.h>
30
+ #include <ATen/ops/_batch_norm_with_update_meta.h>
31
+ #include <ATen/ops/_cast_Byte_meta.h>
32
+ #include <ATen/ops/_cast_Char_meta.h>
33
+ #include <ATen/ops/_cast_Double_meta.h>
34
+ #include <ATen/ops/_cast_Float_meta.h>
35
+ #include <ATen/ops/_cast_Half_meta.h>
36
+ #include <ATen/ops/_cast_Int_meta.h>
37
+ #include <ATen/ops/_cast_Long_meta.h>
38
+ #include <ATen/ops/_cast_Short_meta.h>
39
+ #include <ATen/ops/_cdist_backward_meta.h>
40
+ #include <ATen/ops/_cdist_forward_meta.h>
41
+ #include <ATen/ops/_cholesky_solve_helper_meta.h>
42
+ #include <ATen/ops/_choose_qparams_per_tensor_meta.h>
43
+ #include <ATen/ops/_chunk_cat_meta.h>
44
+ #include <ATen/ops/_coalesce_meta.h>
45
+ #include <ATen/ops/_coalesced_meta.h>
46
+ #include <ATen/ops/_compute_linear_combination_meta.h>
47
+ #include <ATen/ops/_conj_meta.h>
48
+ #include <ATen/ops/_conj_copy_meta.h>
49
+ #include <ATen/ops/_conj_physical_meta.h>
50
+ #include <ATen/ops/_conv_depthwise2d_meta.h>
51
+ #include <ATen/ops/_convert_indices_from_coo_to_csr_meta.h>
52
+ #include <ATen/ops/_convert_indices_from_csr_to_coo_meta.h>
53
+ #include <ATen/ops/_convert_weight_to_int4pack_meta.h>
54
+ #include <ATen/ops/_convert_weight_to_int4pack_for_cpu_meta.h>
55
+ #include <ATen/ops/_convolution_meta.h>
56
+ #include <ATen/ops/_convolution_double_backward_meta.h>
57
+ #include <ATen/ops/_convolution_mode_meta.h>
58
+ #include <ATen/ops/_copy_from_meta.h>
59
+ #include <ATen/ops/_copy_from_and_resize_meta.h>
60
+ #include <ATen/ops/_cslt_compress_meta.h>
61
+ #include <ATen/ops/_cslt_sparse_mm_meta.h>
62
+ #include <ATen/ops/_cslt_sparse_mm_search_meta.h>
63
+ #include <ATen/ops/_ctc_loss_meta.h>
64
+ #include <ATen/ops/_ctc_loss_backward_meta.h>
65
+ #include <ATen/ops/_cudnn_attention_backward_meta.h>
66
+ #include <ATen/ops/_cudnn_attention_forward_meta.h>
67
+ #include <ATen/ops/_cudnn_ctc_loss_meta.h>
68
+ #include <ATen/ops/_cudnn_init_dropout_state_meta.h>
69
+ #include <ATen/ops/_cudnn_rnn_meta.h>
70
+ #include <ATen/ops/_cudnn_rnn_backward_meta.h>
71
+ #include <ATen/ops/_cudnn_rnn_flatten_weight_meta.h>
72
+ #include <ATen/ops/_cufft_clear_plan_cache_meta.h>
73
+ #include <ATen/ops/_cufft_get_plan_cache_max_size_meta.h>
74
+ #include <ATen/ops/_cufft_get_plan_cache_size_meta.h>
75
+ #include <ATen/ops/_cufft_set_plan_cache_max_size_meta.h>
76
+ #include <ATen/ops/_cummax_helper_meta.h>
77
+ #include <ATen/ops/_cummin_helper_meta.h>
78
+ #include <ATen/ops/_debug_has_internal_overlap_meta.h>
79
+ #include <ATen/ops/_dimI_meta.h>
80
+ #include <ATen/ops/_dimV_meta.h>
81
+ #include <ATen/ops/_dim_arange_meta.h>
82
+ #include <ATen/ops/_dirichlet_grad_meta.h>
83
+ #include <ATen/ops/_dyn_quant_matmul_4bit_meta.h>
84
+ #include <ATen/ops/_dyn_quant_pack_4bit_weight_meta.h>
85
+ #include <ATen/ops/_efficient_attention_backward_meta.h>
86
+ #include <ATen/ops/_efficient_attention_forward_meta.h>
87
+ #include <ATen/ops/_efficientzerotensor_meta.h>
88
+ #include <ATen/ops/_embedding_bag_meta.h>
89
+ #include <ATen/ops/_embedding_bag_backward_meta.h>
90
+ #include <ATen/ops/_embedding_bag_dense_backward_meta.h>
91
+ #include <ATen/ops/_embedding_bag_forward_only_meta.h>
92
+ #include <ATen/ops/_embedding_bag_per_sample_weights_backward_meta.h>
93
+ #include <ATen/ops/_embedding_bag_sparse_backward_meta.h>
94
+ #include <ATen/ops/_empty_affine_quantized_meta.h>
95
+ #include <ATen/ops/_empty_per_channel_affine_quantized_meta.h>
96
+ #include <ATen/ops/_euclidean_dist_meta.h>
97
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine_meta.h>
98
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_meta.h>
99
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_meta.h>
100
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_meta.h>
101
+ #include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_meta.h>
102
+ #include <ATen/ops/_fft_c2c_meta.h>
103
+ #include <ATen/ops/_fft_c2r_meta.h>
104
+ #include <ATen/ops/_fft_r2c_meta.h>
105
+ #include <ATen/ops/_fill_mem_eff_dropout_mask_meta.h>
106
+ #include <ATen/ops/_flash_attention_backward_meta.h>
107
+ #include <ATen/ops/_flash_attention_forward_meta.h>
108
+ #include <ATen/ops/_foobar_meta.h>
109
+ #include <ATen/ops/_foreach_abs_meta.h>
110
+ #include <ATen/ops/_foreach_acos_meta.h>
111
+ #include <ATen/ops/_foreach_add_meta.h>
112
+ #include <ATen/ops/_foreach_addcdiv_meta.h>
113
+ #include <ATen/ops/_foreach_addcmul_meta.h>
114
+ #include <ATen/ops/_foreach_asin_meta.h>
115
+ #include <ATen/ops/_foreach_atan_meta.h>
116
+ #include <ATen/ops/_foreach_ceil_meta.h>
117
+ #include <ATen/ops/_foreach_clamp_max_meta.h>
118
+ #include <ATen/ops/_foreach_clamp_min_meta.h>
119
+ #include <ATen/ops/_foreach_copy_meta.h>
120
+ #include <ATen/ops/_foreach_cos_meta.h>
121
+ #include <ATen/ops/_foreach_cosh_meta.h>
122
+ #include <ATen/ops/_foreach_div_meta.h>
123
+ #include <ATen/ops/_foreach_erf_meta.h>
124
+ #include <ATen/ops/_foreach_erfc_meta.h>
125
+ #include <ATen/ops/_foreach_exp_meta.h>
126
+ #include <ATen/ops/_foreach_expm1_meta.h>
127
+ #include <ATen/ops/_foreach_floor_meta.h>
128
+ #include <ATen/ops/_foreach_frac_meta.h>
129
+ #include <ATen/ops/_foreach_lerp_meta.h>
130
+ #include <ATen/ops/_foreach_lgamma_meta.h>
131
+ #include <ATen/ops/_foreach_log_meta.h>
132
+ #include <ATen/ops/_foreach_log10_meta.h>
133
+ #include <ATen/ops/_foreach_log1p_meta.h>
134
+ #include <ATen/ops/_foreach_log2_meta.h>
135
+ #include <ATen/ops/_foreach_max_meta.h>
136
+ #include <ATen/ops/_foreach_maximum_meta.h>
137
+ #include <ATen/ops/_foreach_minimum_meta.h>
138
+ #include <ATen/ops/_foreach_mul_meta.h>
139
+ #include <ATen/ops/_foreach_neg_meta.h>
140
+ #include <ATen/ops/_foreach_norm_meta.h>
141
+ #include <ATen/ops/_foreach_pow_meta.h>
142
+ #include <ATen/ops/_foreach_reciprocal_meta.h>
143
+ #include <ATen/ops/_foreach_round_meta.h>
144
+ #include <ATen/ops/_foreach_rsqrt_meta.h>
145
+ #include <ATen/ops/_foreach_sigmoid_meta.h>
146
+ #include <ATen/ops/_foreach_sign_meta.h>
147
+ #include <ATen/ops/_foreach_sin_meta.h>
148
+ #include <ATen/ops/_foreach_sinh_meta.h>
149
+ #include <ATen/ops/_foreach_sqrt_meta.h>
150
+ #include <ATen/ops/_foreach_sub_meta.h>
151
+ #include <ATen/ops/_foreach_tan_meta.h>
152
+ #include <ATen/ops/_foreach_tanh_meta.h>
153
+ #include <ATen/ops/_foreach_trunc_meta.h>
154
+ #include <ATen/ops/_foreach_zero_meta.h>
155
+ #include <ATen/ops/_functional_assert_async_meta.h>
156
+ #include <ATen/ops/_functional_assert_scalar_meta.h>
157
+ #include <ATen/ops/_functional_sym_constrain_range_meta.h>
158
+ #include <ATen/ops/_functional_sym_constrain_range_for_size_meta.h>
159
+ #include <ATen/ops/_fused_adagrad_meta.h>
160
+ #include <ATen/ops/_fused_adam_meta.h>
161
+ #include <ATen/ops/_fused_adamw_meta.h>
162
+ #include <ATen/ops/_fused_dropout_meta.h>
163
+ #include <ATen/ops/_fused_moving_avg_obs_fq_helper_meta.h>
164
+ #include <ATen/ops/_fused_rms_norm_meta.h>
165
+ #include <ATen/ops/_fused_rms_norm_backward_meta.h>
166
+ #include <ATen/ops/_fused_sdp_choice_meta.h>
167
+ #include <ATen/ops/_fused_sgd_meta.h>
168
+ #include <ATen/ops/_fw_primal_meta.h>
169
+ #include <ATen/ops/_fw_primal_copy_meta.h>
170
+ #include <ATen/ops/_gather_sparse_backward_meta.h>
171
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback_meta.h>
172
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_meta.h>
173
+ #include <ATen/ops/_grouped_mm_meta.h>
174
+ #include <ATen/ops/_has_compatible_shallow_copy_type_meta.h>
175
+ #include <ATen/ops/_has_same_storage_numel_meta.h>
176
+ #include <ATen/ops/_histogramdd_bin_edges_meta.h>
177
+ #include <ATen/ops/_histogramdd_from_bin_cts_meta.h>
178
+ #include <ATen/ops/_histogramdd_from_bin_tensors_meta.h>
179
+ #include <ATen/ops/_index_put_impl_meta.h>
180
+ #include <ATen/ops/_indices_meta.h>
181
+ #include <ATen/ops/_indices_copy_meta.h>
182
+ #include <ATen/ops/_int_mm_meta.h>
183
+ #include <ATen/ops/_is_all_true_meta.h>
184
+ #include <ATen/ops/_is_any_true_meta.h>
185
+ #include <ATen/ops/_is_zerotensor_meta.h>
186
+ #include <ATen/ops/_jagged_to_padded_dense_forward_meta.h>
187
+ #include <ATen/ops/_lazy_clone_meta.h>
188
+ #include <ATen/ops/_linalg_check_errors_meta.h>
189
+ #include <ATen/ops/_linalg_det_meta.h>
190
+ #include <ATen/ops/_linalg_eigh_meta.h>
191
+ #include <ATen/ops/_linalg_eigvals_meta.h>
192
+ #include <ATen/ops/_linalg_slogdet_meta.h>
193
+ #include <ATen/ops/_linalg_solve_ex_meta.h>
194
+ #include <ATen/ops/_linalg_svd_meta.h>
195
+ #include <ATen/ops/_local_scalar_dense_meta.h>
196
+ #include <ATen/ops/_log_softmax_meta.h>
197
+ #include <ATen/ops/_log_softmax_backward_data_meta.h>
198
+ #include <ATen/ops/_logcumsumexp_meta.h>
199
+ #include <ATen/ops/_lstm_mps_meta.h>
200
+ #include <ATen/ops/_lu_with_info_meta.h>
201
+ #include <ATen/ops/_make_dep_token_meta.h>
202
+ #include <ATen/ops/_make_dual_meta.h>
203
+ #include <ATen/ops/_make_dual_copy_meta.h>
204
+ #include <ATen/ops/_make_per_channel_quantized_tensor_meta.h>
205
+ #include <ATen/ops/_make_per_tensor_quantized_tensor_meta.h>
206
+ #include <ATen/ops/_masked_scale_meta.h>
207
+ #include <ATen/ops/_masked_softmax_meta.h>
208
+ #include <ATen/ops/_masked_softmax_backward_meta.h>
209
+ #include <ATen/ops/_mixed_dtypes_linear_meta.h>
210
+ #include <ATen/ops/_mkldnn_reshape_meta.h>
211
+ #include <ATen/ops/_mkldnn_transpose_meta.h>
212
+ #include <ATen/ops/_mps_convolution_meta.h>
213
+ #include <ATen/ops/_mps_convolution_transpose_meta.h>
214
+ #include <ATen/ops/_native_batch_norm_legit_meta.h>
215
+ #include <ATen/ops/_native_batch_norm_legit_no_training_meta.h>
216
+ #include <ATen/ops/_native_multi_head_attention_meta.h>
217
+ #include <ATen/ops/_neg_view_meta.h>
218
+ #include <ATen/ops/_neg_view_copy_meta.h>
219
+ #include <ATen/ops/_nested_compute_contiguous_strides_offsets_meta.h>
220
+ #include <ATen/ops/_nested_from_padded_meta.h>
221
+ #include <ATen/ops/_nested_from_padded_and_nested_example_meta.h>
222
+ #include <ATen/ops/_nested_from_padded_tensor_meta.h>
223
+ #include <ATen/ops/_nested_get_jagged_dummy_meta.h>
224
+ #include <ATen/ops/_nested_get_lengths_meta.h>
225
+ #include <ATen/ops/_nested_get_max_seqlen_meta.h>
226
+ #include <ATen/ops/_nested_get_min_seqlen_meta.h>
227
+ #include <ATen/ops/_nested_get_offsets_meta.h>
228
+ #include <ATen/ops/_nested_get_ragged_idx_meta.h>
229
+ #include <ATen/ops/_nested_get_values_meta.h>
230
+ #include <ATen/ops/_nested_get_values_copy_meta.h>
231
+ #include <ATen/ops/_nested_select_backward_meta.h>
232
+ #include <ATen/ops/_nested_sum_backward_meta.h>
233
+ #include <ATen/ops/_nested_tensor_from_mask_meta.h>
234
+ #include <ATen/ops/_nested_tensor_from_mask_left_aligned_meta.h>
235
+ #include <ATen/ops/_nested_tensor_from_tensor_list_meta.h>
236
+ #include <ATen/ops/_nested_tensor_size_meta.h>
237
+ #include <ATen/ops/_nested_tensor_softmax_with_shape_meta.h>
238
+ #include <ATen/ops/_nested_tensor_storage_offsets_meta.h>
239
+ #include <ATen/ops/_nested_tensor_strides_meta.h>
240
+ #include <ATen/ops/_nested_view_from_buffer_meta.h>
241
+ #include <ATen/ops/_nested_view_from_buffer_copy_meta.h>
242
+ #include <ATen/ops/_nested_view_from_jagged_meta.h>
243
+ #include <ATen/ops/_nested_view_from_jagged_copy_meta.h>
244
+ #include <ATen/ops/_new_zeros_with_same_feature_meta_meta.h>
245
+ #include <ATen/ops/_nnpack_available_meta.h>
246
+ #include <ATen/ops/_nnpack_spatial_convolution_meta.h>
247
+ #include <ATen/ops/_nnz_meta.h>
248
+ #include <ATen/ops/_pack_padded_sequence_meta.h>
249
+ #include <ATen/ops/_pack_padded_sequence_backward_meta.h>
250
+ #include <ATen/ops/_pad_circular_meta.h>
251
+ #include <ATen/ops/_pad_enum_meta.h>
252
+ #include <ATen/ops/_pad_packed_sequence_meta.h>
253
+ #include <ATen/ops/_padded_dense_to_jagged_forward_meta.h>
254
+ #include <ATen/ops/_pdist_backward_meta.h>
255
+ #include <ATen/ops/_pdist_forward_meta.h>
256
+ #include <ATen/ops/_pin_memory_meta.h>
257
+ #include <ATen/ops/_prelu_kernel_meta.h>
258
+ #include <ATen/ops/_prelu_kernel_backward_meta.h>
259
+ #include <ATen/ops/_print_meta.h>
260
+ #include <ATen/ops/_propagate_xla_data_meta.h>
261
+ #include <ATen/ops/_remove_batch_dim_meta.h>
262
+ #include <ATen/ops/_reshape_alias_meta.h>
263
+ #include <ATen/ops/_reshape_alias_copy_meta.h>
264
+ #include <ATen/ops/_reshape_copy_meta.h>
265
+ #include <ATen/ops/_reshape_from_tensor_meta.h>
266
+ #include <ATen/ops/_resize_output_meta.h>
267
+ #include <ATen/ops/_rowwise_prune_meta.h>
268
+ #include <ATen/ops/_safe_softmax_meta.h>
269
+ #include <ATen/ops/_sample_dirichlet_meta.h>
270
+ #include <ATen/ops/_saturate_weight_to_fp16_meta.h>
271
+ #include <ATen/ops/_scaled_dot_product_attention_math_meta.h>
272
+ #include <ATen/ops/_scaled_dot_product_attention_math_for_mps_meta.h>
273
+ #include <ATen/ops/_scaled_dot_product_cudnn_attention_meta.h>
274
+ #include <ATen/ops/_scaled_dot_product_cudnn_attention_backward_meta.h>
275
+ #include <ATen/ops/_scaled_dot_product_efficient_attention_meta.h>
276
+ #include <ATen/ops/_scaled_dot_product_efficient_attention_backward_meta.h>
277
+ #include <ATen/ops/_scaled_dot_product_flash_attention_meta.h>
278
+ #include <ATen/ops/_scaled_dot_product_flash_attention_backward_meta.h>
279
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_meta.h>
280
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_meta.h>
281
+ #include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_meta.h>
282
+ #include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_meta.h>
283
+ #include <ATen/ops/_scaled_grouped_mm_meta.h>
284
+ #include <ATen/ops/_scaled_grouped_mm_v2_meta.h>
285
+ #include <ATen/ops/_scaled_mm_meta.h>
286
+ #include <ATen/ops/_scaled_mm_v2_meta.h>
287
+ #include <ATen/ops/_segment_reduce_backward_meta.h>
288
+ #include <ATen/ops/_shape_as_tensor_meta.h>
289
+ #include <ATen/ops/_slow_conv2d_backward_meta.h>
290
+ #include <ATen/ops/_slow_conv2d_forward_meta.h>
291
+ #include <ATen/ops/_sobol_engine_draw_meta.h>
292
+ #include <ATen/ops/_sobol_engine_ff_meta.h>
293
+ #include <ATen/ops/_sobol_engine_initialize_state_meta.h>
294
+ #include <ATen/ops/_sobol_engine_scramble_meta.h>
295
+ #include <ATen/ops/_softmax_meta.h>
296
+ #include <ATen/ops/_softmax_backward_data_meta.h>
297
+ #include <ATen/ops/_sparse_addmm_meta.h>
298
+ #include <ATen/ops/_sparse_broadcast_to_meta.h>
299
+ #include <ATen/ops/_sparse_broadcast_to_copy_meta.h>
300
+ #include <ATen/ops/_sparse_bsc_tensor_unsafe_meta.h>
301
+ #include <ATen/ops/_sparse_bsr_tensor_unsafe_meta.h>
302
+ #include <ATen/ops/_sparse_compressed_tensor_unsafe_meta.h>
303
+ #include <ATen/ops/_sparse_compressed_tensor_with_dims_meta.h>
304
+ #include <ATen/ops/_sparse_coo_tensor_unsafe_meta.h>
305
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_meta.h>
306
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta.h>
307
+ #include <ATen/ops/_sparse_csc_tensor_unsafe_meta.h>
308
+ #include <ATen/ops/_sparse_csr_prod_meta.h>
309
+ #include <ATen/ops/_sparse_csr_sum_meta.h>
310
+ #include <ATen/ops/_sparse_csr_tensor_unsafe_meta.h>
311
+ #include <ATen/ops/_sparse_log_softmax_meta.h>
312
+ #include <ATen/ops/_sparse_log_softmax_backward_data_meta.h>
313
+ #include <ATen/ops/_sparse_mask_projection_meta.h>
314
+ #include <ATen/ops/_sparse_mm_meta.h>
315
+ #include <ATen/ops/_sparse_mm_reduce_impl_meta.h>
316
+ #include <ATen/ops/_sparse_mm_reduce_impl_backward_meta.h>
317
+ #include <ATen/ops/_sparse_semi_structured_addmm_meta.h>
318
+ #include <ATen/ops/_sparse_semi_structured_apply_meta.h>
319
+ #include <ATen/ops/_sparse_semi_structured_apply_dense_meta.h>
320
+ #include <ATen/ops/_sparse_semi_structured_linear_meta.h>
321
+ #include <ATen/ops/_sparse_semi_structured_mm_meta.h>
322
+ #include <ATen/ops/_sparse_semi_structured_tile_meta.h>
323
+ #include <ATen/ops/_sparse_softmax_meta.h>
324
+ #include <ATen/ops/_sparse_softmax_backward_data_meta.h>
325
+ #include <ATen/ops/_sparse_sparse_matmul_meta.h>
326
+ #include <ATen/ops/_sparse_sum_meta.h>
327
+ #include <ATen/ops/_sparse_sum_backward_meta.h>
328
+ #include <ATen/ops/_spdiags_meta.h>
329
+ #include <ATen/ops/_spsolve_meta.h>
330
+ #include <ATen/ops/_stack_meta.h>
331
+ #include <ATen/ops/_standard_gamma_meta.h>
332
+ #include <ATen/ops/_standard_gamma_grad_meta.h>
333
+ #include <ATen/ops/_test_ambiguous_defaults_meta.h>
334
+ #include <ATen/ops/_test_autograd_multiple_dispatch_meta.h>
335
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_meta.h>
336
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_meta.h>
337
+ #include <ATen/ops/_test_check_tensor_meta.h>
338
+ #include <ATen/ops/_test_functorch_fallback_meta.h>
339
+ #include <ATen/ops/_test_optional_filled_intlist_meta.h>
340
+ #include <ATen/ops/_test_optional_floatlist_meta.h>
341
+ #include <ATen/ops/_test_optional_intlist_meta.h>
342
+ #include <ATen/ops/_test_parallel_materialize_meta.h>
343
+ #include <ATen/ops/_test_serialization_subcmul_meta.h>
344
+ #include <ATen/ops/_test_string_default_meta.h>
345
+ #include <ATen/ops/_test_warn_in_autograd_meta.h>
346
+ #include <ATen/ops/_thnn_differentiable_gru_cell_backward_meta.h>
347
+ #include <ATen/ops/_thnn_differentiable_lstm_cell_backward_meta.h>
348
+ #include <ATen/ops/_thnn_fused_gru_cell_meta.h>
349
+ #include <ATen/ops/_thnn_fused_gru_cell_backward_meta.h>
350
+ #include <ATen/ops/_thnn_fused_lstm_cell_meta.h>
351
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward_meta.h>
352
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_meta.h>
353
+ #include <ATen/ops/_to_copy_meta.h>
354
+ #include <ATen/ops/_to_cpu_meta.h>
355
+ #include <ATen/ops/_to_dense_meta.h>
356
+ #include <ATen/ops/_to_sparse_meta.h>
357
+ #include <ATen/ops/_to_sparse_bsc_meta.h>
358
+ #include <ATen/ops/_to_sparse_bsr_meta.h>
359
+ #include <ATen/ops/_to_sparse_csc_meta.h>
360
+ #include <ATen/ops/_to_sparse_csr_meta.h>
361
+ #include <ATen/ops/_to_sparse_semi_structured_meta.h>
362
+ #include <ATen/ops/_transform_bias_rescale_qkv_meta.h>
363
+ #include <ATen/ops/_transformer_encoder_layer_fwd_meta.h>
364
+ #include <ATen/ops/_trilinear_meta.h>
365
+ #include <ATen/ops/_triton_multi_head_attention_meta.h>
366
+ #include <ATen/ops/_triton_scaled_dot_attention_meta.h>
367
+ #include <ATen/ops/_unique_meta.h>
368
+ #include <ATen/ops/_unique2_meta.h>
369
+ #include <ATen/ops/_unpack_dual_meta.h>
370
+ #include <ATen/ops/_unsafe_index_meta.h>
371
+ #include <ATen/ops/_unsafe_index_put_meta.h>
372
+ #include <ATen/ops/_unsafe_masked_index_meta.h>
373
+ #include <ATen/ops/_unsafe_masked_index_put_accumulate_meta.h>
374
+ #include <ATen/ops/_unsafe_view_meta.h>
375
+ #include <ATen/ops/_upsample_bicubic2d_aa_meta.h>
376
+ #include <ATen/ops/_upsample_bicubic2d_aa_backward_meta.h>
377
+ #include <ATen/ops/_upsample_bilinear2d_aa_meta.h>
378
+ #include <ATen/ops/_upsample_bilinear2d_aa_backward_meta.h>
379
+ #include <ATen/ops/_upsample_nearest_exact1d_meta.h>
380
+ #include <ATen/ops/_upsample_nearest_exact1d_backward_meta.h>
381
+ #include <ATen/ops/_upsample_nearest_exact2d_meta.h>
382
+ #include <ATen/ops/_upsample_nearest_exact2d_backward_meta.h>
383
+ #include <ATen/ops/_upsample_nearest_exact3d_meta.h>
384
+ #include <ATen/ops/_upsample_nearest_exact3d_backward_meta.h>
385
+ #include <ATen/ops/_use_cudnn_ctc_loss_meta.h>
386
+ #include <ATen/ops/_use_cudnn_rnn_flatten_weight_meta.h>
387
+ #include <ATen/ops/_validate_compressed_sparse_indices_meta.h>
388
+ #include <ATen/ops/_validate_sparse_bsc_tensor_args_meta.h>
389
+ #include <ATen/ops/_validate_sparse_bsr_tensor_args_meta.h>
390
+ #include <ATen/ops/_validate_sparse_compressed_tensor_args_meta.h>
391
+ #include <ATen/ops/_validate_sparse_coo_tensor_args_meta.h>
392
+ #include <ATen/ops/_validate_sparse_csc_tensor_args_meta.h>
393
+ #include <ATen/ops/_validate_sparse_csr_tensor_args_meta.h>
394
+ #include <ATen/ops/_values_meta.h>
395
+ #include <ATen/ops/_values_copy_meta.h>
396
+ #include <ATen/ops/_version_meta.h>
397
+ #include <ATen/ops/_weight_int4pack_mm_meta.h>
398
+ #include <ATen/ops/_weight_int4pack_mm_for_cpu_meta.h>
399
+ #include <ATen/ops/_weight_int4pack_mm_with_scales_and_zeros_meta.h>
400
+ #include <ATen/ops/_weight_int8pack_mm_meta.h>
401
+ #include <ATen/ops/_weight_norm_meta.h>
402
+ #include <ATen/ops/_weight_norm_differentiable_backward_meta.h>
403
+ #include <ATen/ops/_weight_norm_interface_meta.h>
404
+ #include <ATen/ops/_weight_norm_interface_backward_meta.h>
405
+ #include <ATen/ops/_wrapped_linear_prepack_meta.h>
406
+ #include <ATen/ops/_wrapped_quantized_linear_prepacked_meta.h>
407
+ #include <ATen/ops/abs_meta.h>
408
+ #include <ATen/ops/absolute_meta.h>
409
+ #include <ATen/ops/acos_meta.h>
410
+ #include <ATen/ops/acosh_meta.h>
411
+ #include <ATen/ops/adaptive_avg_pool1d_meta.h>
412
+ #include <ATen/ops/adaptive_avg_pool2d_meta.h>
413
+ #include <ATen/ops/adaptive_avg_pool3d_meta.h>
414
+ #include <ATen/ops/adaptive_avg_pool3d_backward_meta.h>
415
+ #include <ATen/ops/adaptive_max_pool1d_meta.h>
416
+ #include <ATen/ops/adaptive_max_pool2d_meta.h>
417
+ #include <ATen/ops/adaptive_max_pool2d_backward_meta.h>
418
+ #include <ATen/ops/adaptive_max_pool3d_meta.h>
419
+ #include <ATen/ops/adaptive_max_pool3d_backward_meta.h>
420
+ #include <ATen/ops/add_meta.h>
421
+ #include <ATen/ops/addbmm_meta.h>
422
+ #include <ATen/ops/addcdiv_meta.h>
423
+ #include <ATen/ops/addcmul_meta.h>
424
+ #include <ATen/ops/addmm_meta.h>
425
+ #include <ATen/ops/addmv_meta.h>
426
+ #include <ATen/ops/addr_meta.h>
427
+ #include <ATen/ops/adjoint_meta.h>
428
+ #include <ATen/ops/affine_grid_generator_meta.h>
429
+ #include <ATen/ops/affine_grid_generator_backward_meta.h>
430
+ #include <ATen/ops/alias_meta.h>
431
+ #include <ATen/ops/alias_copy_meta.h>
432
+ #include <ATen/ops/align_as_meta.h>
433
+ #include <ATen/ops/align_tensors_meta.h>
434
+ #include <ATen/ops/align_to_meta.h>
435
+ #include <ATen/ops/all_meta.h>
436
+ #include <ATen/ops/allclose_meta.h>
437
+ #include <ATen/ops/alpha_dropout_meta.h>
438
+ #include <ATen/ops/amax_meta.h>
439
+ #include <ATen/ops/amin_meta.h>
440
+ #include <ATen/ops/aminmax_meta.h>
441
+ #include <ATen/ops/and_meta.h>
442
+ #include <ATen/ops/angle_meta.h>
443
+ #include <ATen/ops/any_meta.h>
444
+ #include <ATen/ops/arange_meta.h>
445
+ #include <ATen/ops/arccos_meta.h>
446
+ #include <ATen/ops/arccosh_meta.h>
447
+ #include <ATen/ops/arcsin_meta.h>
448
+ #include <ATen/ops/arcsinh_meta.h>
449
+ #include <ATen/ops/arctan_meta.h>
450
+ #include <ATen/ops/arctan2_meta.h>
451
+ #include <ATen/ops/arctanh_meta.h>
452
+ #include <ATen/ops/argmax_meta.h>
453
+ #include <ATen/ops/argmin_meta.h>
454
+ #include <ATen/ops/argsort_meta.h>
455
+ #include <ATen/ops/argwhere_meta.h>
456
+ #include <ATen/ops/as_strided_meta.h>
457
+ #include <ATen/ops/as_strided_copy_meta.h>
458
+ #include <ATen/ops/as_strided_scatter_meta.h>
459
+ #include <ATen/ops/asin_meta.h>
460
+ #include <ATen/ops/asinh_meta.h>
461
+ #include <ATen/ops/atan_meta.h>
462
+ #include <ATen/ops/atan2_meta.h>
463
+ #include <ATen/ops/atanh_meta.h>
464
+ #include <ATen/ops/atleast_1d_meta.h>
465
+ #include <ATen/ops/atleast_2d_meta.h>
466
+ #include <ATen/ops/atleast_3d_meta.h>
467
+ #include <ATen/ops/avg_pool1d_meta.h>
468
+ #include <ATen/ops/avg_pool2d_meta.h>
469
+ #include <ATen/ops/avg_pool2d_backward_meta.h>
470
+ #include <ATen/ops/avg_pool3d_meta.h>
471
+ #include <ATen/ops/avg_pool3d_backward_meta.h>
472
+ #include <ATen/ops/baddbmm_meta.h>
473
+ #include <ATen/ops/bartlett_window_meta.h>
474
+ #include <ATen/ops/batch_norm_meta.h>
475
+ #include <ATen/ops/batch_norm_backward_meta.h>
476
+ #include <ATen/ops/batch_norm_backward_elemt_meta.h>
477
+ #include <ATen/ops/batch_norm_backward_reduce_meta.h>
478
+ #include <ATen/ops/batch_norm_elemt_meta.h>
479
+ #include <ATen/ops/batch_norm_gather_stats_meta.h>
480
+ #include <ATen/ops/batch_norm_gather_stats_with_counts_meta.h>
481
+ #include <ATen/ops/batch_norm_stats_meta.h>
482
+ #include <ATen/ops/batch_norm_update_stats_meta.h>
483
+ #include <ATen/ops/bernoulli_meta.h>
484
+ #include <ATen/ops/bilinear_meta.h>
485
+ #include <ATen/ops/binary_cross_entropy_meta.h>
486
+ #include <ATen/ops/binary_cross_entropy_backward_meta.h>
487
+ #include <ATen/ops/binary_cross_entropy_with_logits_meta.h>
488
+ #include <ATen/ops/bincount_meta.h>
489
+ #include <ATen/ops/binomial_meta.h>
490
+ #include <ATen/ops/bitwise_and_meta.h>
491
+ #include <ATen/ops/bitwise_left_shift_meta.h>
492
+ #include <ATen/ops/bitwise_not_meta.h>
493
+ #include <ATen/ops/bitwise_or_meta.h>
494
+ #include <ATen/ops/bitwise_right_shift_meta.h>
495
+ #include <ATen/ops/bitwise_xor_meta.h>
496
+ #include <ATen/ops/blackman_window_meta.h>
497
+ #include <ATen/ops/block_diag_meta.h>
498
+ #include <ATen/ops/bmm_meta.h>
499
+ #include <ATen/ops/broadcast_tensors_meta.h>
500
+ #include <ATen/ops/broadcast_to_meta.h>
501
+ #include <ATen/ops/bucketize_meta.h>
502
+ #include <ATen/ops/can_cast_meta.h>
503
+ #include <ATen/ops/cartesian_prod_meta.h>
504
+ #include <ATen/ops/cat_meta.h>
505
+ #include <ATen/ops/cauchy_meta.h>
506
+ #include <ATen/ops/ccol_indices_meta.h>
507
+ #include <ATen/ops/ccol_indices_copy_meta.h>
508
+ #include <ATen/ops/cdist_meta.h>
509
+ #include <ATen/ops/ceil_meta.h>
510
+ #include <ATen/ops/celu_meta.h>
511
+ #include <ATen/ops/chain_matmul_meta.h>
512
+ #include <ATen/ops/chalf_meta.h>
513
+ #include <ATen/ops/channel_shuffle_meta.h>
514
+ #include <ATen/ops/cholesky_meta.h>
515
+ #include <ATen/ops/cholesky_inverse_meta.h>
516
+ #include <ATen/ops/cholesky_solve_meta.h>
517
+ #include <ATen/ops/choose_qparams_optimized_meta.h>
518
+ #include <ATen/ops/chunk_meta.h>
519
+ #include <ATen/ops/clamp_meta.h>
520
+ #include <ATen/ops/clamp_max_meta.h>
521
+ #include <ATen/ops/clamp_min_meta.h>
522
+ #include <ATen/ops/clip_meta.h>
523
+ #include <ATen/ops/clone_meta.h>
524
+ #include <ATen/ops/coalesce_meta.h>
525
+ #include <ATen/ops/col2im_meta.h>
526
+ #include <ATen/ops/col_indices_meta.h>
527
+ #include <ATen/ops/col_indices_copy_meta.h>
528
+ #include <ATen/ops/column_stack_meta.h>
529
+ #include <ATen/ops/combinations_meta.h>
530
+ #include <ATen/ops/complex_meta.h>
531
+ #include <ATen/ops/concat_meta.h>
532
+ #include <ATen/ops/concatenate_meta.h>
533
+ #include <ATen/ops/conj_meta.h>
534
+ #include <ATen/ops/conj_physical_meta.h>
535
+ #include <ATen/ops/constant_pad_nd_meta.h>
536
+ #include <ATen/ops/contiguous_meta.h>
537
+ #include <ATen/ops/conv1d_meta.h>
538
+ #include <ATen/ops/conv2d_meta.h>
539
+ #include <ATen/ops/conv3d_meta.h>
540
+ #include <ATen/ops/conv_depthwise3d_meta.h>
541
+ #include <ATen/ops/conv_tbc_meta.h>
542
+ #include <ATen/ops/conv_tbc_backward_meta.h>
543
+ #include <ATen/ops/conv_transpose1d_meta.h>
544
+ #include <ATen/ops/conv_transpose2d_meta.h>
545
+ #include <ATen/ops/conv_transpose3d_meta.h>
546
+ #include <ATen/ops/convolution_meta.h>
547
+ #include <ATen/ops/convolution_backward_meta.h>
548
+ #include <ATen/ops/convolution_backward_overrideable_meta.h>
549
+ #include <ATen/ops/convolution_overrideable_meta.h>
550
+ #include <ATen/ops/copy_meta.h>
551
+ #include <ATen/ops/copy_sparse_to_sparse_meta.h>
552
+ #include <ATen/ops/copysign_meta.h>
553
+ #include <ATen/ops/corrcoef_meta.h>
554
+ #include <ATen/ops/cos_meta.h>
555
+ #include <ATen/ops/cosh_meta.h>
556
+ #include <ATen/ops/cosine_embedding_loss_meta.h>
557
+ #include <ATen/ops/cosine_similarity_meta.h>
558
+ #include <ATen/ops/count_nonzero_meta.h>
559
+ #include <ATen/ops/cov_meta.h>
560
+ #include <ATen/ops/cross_meta.h>
561
+ #include <ATen/ops/cross_entropy_loss_meta.h>
562
+ #include <ATen/ops/crow_indices_meta.h>
563
+ #include <ATen/ops/crow_indices_copy_meta.h>
564
+ #include <ATen/ops/ctc_loss_meta.h>
565
+ #include <ATen/ops/cudnn_affine_grid_generator_meta.h>
566
+ #include <ATen/ops/cudnn_affine_grid_generator_backward_meta.h>
567
+ #include <ATen/ops/cudnn_batch_norm_meta.h>
568
+ #include <ATen/ops/cudnn_batch_norm_backward_meta.h>
569
+ #include <ATen/ops/cudnn_convolution_meta.h>
570
+ #include <ATen/ops/cudnn_convolution_add_relu_meta.h>
571
+ #include <ATen/ops/cudnn_convolution_relu_meta.h>
572
+ #include <ATen/ops/cudnn_convolution_transpose_meta.h>
573
+ #include <ATen/ops/cudnn_grid_sampler_meta.h>
574
+ #include <ATen/ops/cudnn_grid_sampler_backward_meta.h>
575
+ #include <ATen/ops/cudnn_is_acceptable_meta.h>
576
+ #include <ATen/ops/cummax_meta.h>
577
+ #include <ATen/ops/cummaxmin_backward_meta.h>
578
+ #include <ATen/ops/cummin_meta.h>
579
+ #include <ATen/ops/cumprod_meta.h>
580
+ #include <ATen/ops/cumprod_backward_meta.h>
581
+ #include <ATen/ops/cumsum_meta.h>
582
+ #include <ATen/ops/cumulative_trapezoid_meta.h>
583
+ #include <ATen/ops/data_meta.h>
584
+ #include <ATen/ops/deg2rad_meta.h>
585
+ #include <ATen/ops/dense_dim_meta.h>
586
+ #include <ATen/ops/dequantize_meta.h>
587
+ #include <ATen/ops/det_meta.h>
588
+ #include <ATen/ops/detach_meta.h>
589
+ #include <ATen/ops/detach_copy_meta.h>
590
+ #include <ATen/ops/diag_meta.h>
591
+ #include <ATen/ops/diag_embed_meta.h>
592
+ #include <ATen/ops/diagflat_meta.h>
593
+ #include <ATen/ops/diagonal_meta.h>
594
+ #include <ATen/ops/diagonal_backward_meta.h>
595
+ #include <ATen/ops/diagonal_copy_meta.h>
596
+ #include <ATen/ops/diagonal_scatter_meta.h>
597
+ #include <ATen/ops/diff_meta.h>
598
+ #include <ATen/ops/digamma_meta.h>
599
+ #include <ATen/ops/dist_meta.h>
600
+ #include <ATen/ops/div_meta.h>
601
+ #include <ATen/ops/divide_meta.h>
602
+ #include <ATen/ops/dot_meta.h>
603
+ #include <ATen/ops/dropout_meta.h>
604
+ #include <ATen/ops/dsplit_meta.h>
605
+ #include <ATen/ops/dstack_meta.h>
606
+ #include <ATen/ops/einsum_meta.h>
607
+ #include <ATen/ops/elu_meta.h>
608
+ #include <ATen/ops/elu_backward_meta.h>
609
+ #include <ATen/ops/embedding_meta.h>
610
+ #include <ATen/ops/embedding_backward_meta.h>
611
+ #include <ATen/ops/embedding_bag_meta.h>
612
+ #include <ATen/ops/embedding_dense_backward_meta.h>
613
+ #include <ATen/ops/embedding_renorm_meta.h>
614
+ #include <ATen/ops/embedding_sparse_backward_meta.h>
615
+ #include <ATen/ops/empty_meta.h>
616
+ #include <ATen/ops/empty_like_meta.h>
617
+ #include <ATen/ops/empty_permuted_meta.h>
618
+ #include <ATen/ops/empty_quantized_meta.h>
619
+ #include <ATen/ops/empty_strided_meta.h>
620
+ #include <ATen/ops/eq_meta.h>
621
+ #include <ATen/ops/equal_meta.h>
622
+ #include <ATen/ops/erf_meta.h>
623
+ #include <ATen/ops/erfc_meta.h>
624
+ #include <ATen/ops/erfinv_meta.h>
625
+ #include <ATen/ops/exp_meta.h>
626
+ #include <ATen/ops/exp2_meta.h>
627
+ #include <ATen/ops/expand_meta.h>
628
+ #include <ATen/ops/expand_as_meta.h>
629
+ #include <ATen/ops/expand_copy_meta.h>
630
+ #include <ATen/ops/expm1_meta.h>
631
+ #include <ATen/ops/exponential_meta.h>
632
+ #include <ATen/ops/eye_meta.h>
633
+ #include <ATen/ops/fake_quantize_per_channel_affine_meta.h>
634
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_meta.h>
635
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_meta.h>
636
+ #include <ATen/ops/fake_quantize_per_tensor_affine_meta.h>
637
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_meta.h>
638
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_meta.h>
639
+ #include <ATen/ops/fbgemm_linear_fp16_weight_meta.h>
640
+ #include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_meta.h>
641
+ #include <ATen/ops/fbgemm_linear_int8_weight_meta.h>
642
+ #include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_meta.h>
643
+ #include <ATen/ops/fbgemm_linear_quantize_weight_meta.h>
644
+ #include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_meta.h>
645
+ #include <ATen/ops/fbgemm_pack_quantized_matrix_meta.h>
646
+ #include <ATen/ops/feature_alpha_dropout_meta.h>
647
+ #include <ATen/ops/feature_dropout_meta.h>
648
+ #include <ATen/ops/fft_fft_meta.h>
649
+ #include <ATen/ops/fft_fft2_meta.h>
650
+ #include <ATen/ops/fft_fftfreq_meta.h>
651
+ #include <ATen/ops/fft_fftn_meta.h>
652
+ #include <ATen/ops/fft_fftshift_meta.h>
653
+ #include <ATen/ops/fft_hfft_meta.h>
654
+ #include <ATen/ops/fft_hfft2_meta.h>
655
+ #include <ATen/ops/fft_hfftn_meta.h>
656
+ #include <ATen/ops/fft_ifft_meta.h>
657
+ #include <ATen/ops/fft_ifft2_meta.h>
658
+ #include <ATen/ops/fft_ifftn_meta.h>
659
+ #include <ATen/ops/fft_ifftshift_meta.h>
660
+ #include <ATen/ops/fft_ihfft_meta.h>
661
+ #include <ATen/ops/fft_ihfft2_meta.h>
662
+ #include <ATen/ops/fft_ihfftn_meta.h>
663
+ #include <ATen/ops/fft_irfft_meta.h>
664
+ #include <ATen/ops/fft_irfft2_meta.h>
665
+ #include <ATen/ops/fft_irfftn_meta.h>
666
+ #include <ATen/ops/fft_rfft_meta.h>
667
+ #include <ATen/ops/fft_rfft2_meta.h>
668
+ #include <ATen/ops/fft_rfftfreq_meta.h>
669
+ #include <ATen/ops/fft_rfftn_meta.h>
670
+ #include <ATen/ops/fill_meta.h>
671
+ #include <ATen/ops/fill_diagonal_meta.h>
672
+ #include <ATen/ops/fix_meta.h>
673
+ #include <ATen/ops/flatten_meta.h>
674
+ #include <ATen/ops/flatten_dense_tensors_meta.h>
675
+ #include <ATen/ops/flip_meta.h>
676
+ #include <ATen/ops/fliplr_meta.h>
677
+ #include <ATen/ops/flipud_meta.h>
678
+ #include <ATen/ops/float_power_meta.h>
679
+ #include <ATen/ops/floor_meta.h>
680
+ #include <ATen/ops/floor_divide_meta.h>
681
+ #include <ATen/ops/fmax_meta.h>
682
+ #include <ATen/ops/fmin_meta.h>
683
+ #include <ATen/ops/fmod_meta.h>
684
+ #include <ATen/ops/frac_meta.h>
685
+ #include <ATen/ops/fractional_max_pool2d_meta.h>
686
+ #include <ATen/ops/fractional_max_pool2d_backward_meta.h>
687
+ #include <ATen/ops/fractional_max_pool3d_meta.h>
688
+ #include <ATen/ops/fractional_max_pool3d_backward_meta.h>
689
+ #include <ATen/ops/frexp_meta.h>
690
+ #include <ATen/ops/frobenius_norm_meta.h>
691
+ #include <ATen/ops/from_file_meta.h>
692
+ #include <ATen/ops/full_meta.h>
693
+ #include <ATen/ops/full_like_meta.h>
694
+ #include <ATen/ops/fused_moving_avg_obs_fake_quant_meta.h>
695
+ #include <ATen/ops/gather_meta.h>
696
+ #include <ATen/ops/gather_backward_meta.h>
697
+ #include <ATen/ops/gcd_meta.h>
698
+ #include <ATen/ops/ge_meta.h>
699
+ #include <ATen/ops/gelu_meta.h>
700
+ #include <ATen/ops/gelu_backward_meta.h>
701
+ #include <ATen/ops/geometric_meta.h>
702
+ #include <ATen/ops/geqrf_meta.h>
703
+ #include <ATen/ops/ger_meta.h>
704
+ #include <ATen/ops/glu_meta.h>
705
+ #include <ATen/ops/glu_backward_meta.h>
706
+ #include <ATen/ops/glu_backward_jvp_meta.h>
707
+ #include <ATen/ops/glu_jvp_meta.h>
708
+ #include <ATen/ops/gradient_meta.h>
709
+ #include <ATen/ops/greater_meta.h>
710
+ #include <ATen/ops/greater_equal_meta.h>
711
+ #include <ATen/ops/grid_sampler_meta.h>
712
+ #include <ATen/ops/grid_sampler_2d_meta.h>
713
+ #include <ATen/ops/grid_sampler_2d_backward_meta.h>
714
+ #include <ATen/ops/grid_sampler_3d_meta.h>
715
+ #include <ATen/ops/grid_sampler_3d_backward_meta.h>
716
+ #include <ATen/ops/group_norm_meta.h>
717
+ #include <ATen/ops/gru_meta.h>
718
+ #include <ATen/ops/gru_cell_meta.h>
719
+ #include <ATen/ops/gt_meta.h>
720
+ #include <ATen/ops/hamming_window_meta.h>
721
+ #include <ATen/ops/hann_window_meta.h>
722
+ #include <ATen/ops/hardshrink_meta.h>
723
+ #include <ATen/ops/hardshrink_backward_meta.h>
724
+ #include <ATen/ops/hardsigmoid_meta.h>
725
+ #include <ATen/ops/hardsigmoid_backward_meta.h>
726
+ #include <ATen/ops/hardswish_meta.h>
727
+ #include <ATen/ops/hardswish_backward_meta.h>
728
+ #include <ATen/ops/hardtanh_meta.h>
729
+ #include <ATen/ops/hardtanh_backward_meta.h>
730
+ #include <ATen/ops/hash_tensor_meta.h>
731
+ #include <ATen/ops/heaviside_meta.h>
732
+ #include <ATen/ops/hinge_embedding_loss_meta.h>
733
+ #include <ATen/ops/histc_meta.h>
734
+ #include <ATen/ops/histogram_meta.h>
735
+ #include <ATen/ops/histogramdd_meta.h>
736
+ #include <ATen/ops/hsplit_meta.h>
737
+ #include <ATen/ops/hspmm_meta.h>
738
+ #include <ATen/ops/hstack_meta.h>
739
+ #include <ATen/ops/huber_loss_meta.h>
740
+ #include <ATen/ops/huber_loss_backward_meta.h>
741
+ #include <ATen/ops/hypot_meta.h>
742
+ #include <ATen/ops/i0_meta.h>
743
+ #include <ATen/ops/igamma_meta.h>
744
+ #include <ATen/ops/igammac_meta.h>
745
+ #include <ATen/ops/im2col_meta.h>
746
+ #include <ATen/ops/imag_meta.h>
747
+ #include <ATen/ops/index_meta.h>
748
+ #include <ATen/ops/index_add_meta.h>
749
+ #include <ATen/ops/index_copy_meta.h>
750
+ #include <ATen/ops/index_fill_meta.h>
751
+ #include <ATen/ops/index_put_meta.h>
752
+ #include <ATen/ops/index_reduce_meta.h>
753
+ #include <ATen/ops/index_select_meta.h>
754
+ #include <ATen/ops/index_select_backward_meta.h>
755
+ #include <ATen/ops/indices_meta.h>
756
+ #include <ATen/ops/indices_copy_meta.h>
757
+ #include <ATen/ops/infinitely_differentiable_gelu_backward_meta.h>
758
+ #include <ATen/ops/inner_meta.h>
759
+ #include <ATen/ops/instance_norm_meta.h>
760
+ #include <ATen/ops/int_repr_meta.h>
761
+ #include <ATen/ops/inverse_meta.h>
762
+ #include <ATen/ops/is_coalesced_meta.h>
763
+ #include <ATen/ops/is_complex_meta.h>
764
+ #include <ATen/ops/is_conj_meta.h>
765
+ #include <ATen/ops/is_distributed_meta.h>
766
+ #include <ATen/ops/is_floating_point_meta.h>
767
+ #include <ATen/ops/is_inference_meta.h>
768
+ #include <ATen/ops/is_leaf_meta.h>
769
+ #include <ATen/ops/is_neg_meta.h>
770
+ #include <ATen/ops/is_nonzero_meta.h>
771
+ #include <ATen/ops/is_pinned_meta.h>
772
+ #include <ATen/ops/is_same_size_meta.h>
773
+ #include <ATen/ops/is_set_to_meta.h>
774
+ #include <ATen/ops/is_signed_meta.h>
775
+ #include <ATen/ops/is_vulkan_available_meta.h>
776
+ #include <ATen/ops/isclose_meta.h>
777
+ #include <ATen/ops/isfinite_meta.h>
778
+ #include <ATen/ops/isin_meta.h>
779
+ #include <ATen/ops/isinf_meta.h>
780
+ #include <ATen/ops/isnan_meta.h>
781
+ #include <ATen/ops/isneginf_meta.h>
782
+ #include <ATen/ops/isposinf_meta.h>
783
+ #include <ATen/ops/isreal_meta.h>
784
+ #include <ATen/ops/istft_meta.h>
785
+ #include <ATen/ops/item_meta.h>
786
+ #include <ATen/ops/kaiser_window_meta.h>
787
+ #include <ATen/ops/kl_div_meta.h>
788
+ #include <ATen/ops/kron_meta.h>
789
+ #include <ATen/ops/kthvalue_meta.h>
790
+ #include <ATen/ops/l1_loss_meta.h>
791
+ #include <ATen/ops/layer_norm_meta.h>
792
+ #include <ATen/ops/lcm_meta.h>
793
+ #include <ATen/ops/ldexp_meta.h>
794
+ #include <ATen/ops/le_meta.h>
795
+ #include <ATen/ops/leaky_relu_meta.h>
796
+ #include <ATen/ops/leaky_relu_backward_meta.h>
797
+ #include <ATen/ops/lerp_meta.h>
798
+ #include <ATen/ops/less_meta.h>
799
+ #include <ATen/ops/less_equal_meta.h>
800
+ #include <ATen/ops/lgamma_meta.h>
801
+ #include <ATen/ops/lift_meta.h>
802
+ #include <ATen/ops/lift_fresh_meta.h>
803
+ #include <ATen/ops/lift_fresh_copy_meta.h>
804
+ #include <ATen/ops/linalg_cholesky_meta.h>
805
+ #include <ATen/ops/linalg_cholesky_ex_meta.h>
806
+ #include <ATen/ops/linalg_cond_meta.h>
807
+ #include <ATen/ops/linalg_cross_meta.h>
808
+ #include <ATen/ops/linalg_det_meta.h>
809
+ #include <ATen/ops/linalg_diagonal_meta.h>
810
+ #include <ATen/ops/linalg_eig_meta.h>
811
+ #include <ATen/ops/linalg_eigh_meta.h>
812
+ #include <ATen/ops/linalg_eigvals_meta.h>
813
+ #include <ATen/ops/linalg_eigvalsh_meta.h>
814
+ #include <ATen/ops/linalg_householder_product_meta.h>
815
+ #include <ATen/ops/linalg_inv_meta.h>
816
+ #include <ATen/ops/linalg_inv_ex_meta.h>
817
+ #include <ATen/ops/linalg_ldl_factor_meta.h>
818
+ #include <ATen/ops/linalg_ldl_factor_ex_meta.h>
819
+ #include <ATen/ops/linalg_ldl_solve_meta.h>
820
+ #include <ATen/ops/linalg_lstsq_meta.h>
821
+ #include <ATen/ops/linalg_lu_meta.h>
822
+ #include <ATen/ops/linalg_lu_factor_meta.h>
823
+ #include <ATen/ops/linalg_lu_factor_ex_meta.h>
824
+ #include <ATen/ops/linalg_lu_solve_meta.h>
825
+ #include <ATen/ops/linalg_matmul_meta.h>
826
+ #include <ATen/ops/linalg_matrix_exp_meta.h>
827
+ #include <ATen/ops/linalg_matrix_norm_meta.h>
828
+ #include <ATen/ops/linalg_matrix_power_meta.h>
829
+ #include <ATen/ops/linalg_matrix_rank_meta.h>
830
+ #include <ATen/ops/linalg_multi_dot_meta.h>
831
+ #include <ATen/ops/linalg_norm_meta.h>
832
+ #include <ATen/ops/linalg_pinv_meta.h>
833
+ #include <ATen/ops/linalg_qr_meta.h>
834
+ #include <ATen/ops/linalg_slogdet_meta.h>
835
+ #include <ATen/ops/linalg_solve_meta.h>
836
+ #include <ATen/ops/linalg_solve_ex_meta.h>
837
+ #include <ATen/ops/linalg_solve_triangular_meta.h>
838
+ #include <ATen/ops/linalg_svd_meta.h>
839
+ #include <ATen/ops/linalg_svdvals_meta.h>
840
+ #include <ATen/ops/linalg_tensorinv_meta.h>
841
+ #include <ATen/ops/linalg_tensorsolve_meta.h>
842
+ #include <ATen/ops/linalg_vander_meta.h>
843
+ #include <ATen/ops/linalg_vecdot_meta.h>
844
+ #include <ATen/ops/linalg_vector_norm_meta.h>
845
+ #include <ATen/ops/linear_meta.h>
846
+ #include <ATen/ops/linear_backward_meta.h>
847
+ #include <ATen/ops/linspace_meta.h>
848
+ #include <ATen/ops/log_meta.h>
849
+ #include <ATen/ops/log10_meta.h>
850
+ #include <ATen/ops/log1p_meta.h>
851
+ #include <ATen/ops/log2_meta.h>
852
+ #include <ATen/ops/log_normal_meta.h>
853
+ #include <ATen/ops/log_sigmoid_meta.h>
854
+ #include <ATen/ops/log_sigmoid_backward_meta.h>
855
+ #include <ATen/ops/log_sigmoid_forward_meta.h>
856
+ #include <ATen/ops/log_softmax_meta.h>
857
+ #include <ATen/ops/logaddexp_meta.h>
858
+ #include <ATen/ops/logaddexp2_meta.h>
859
+ #include <ATen/ops/logcumsumexp_meta.h>
860
+ #include <ATen/ops/logdet_meta.h>
861
+ #include <ATen/ops/logical_and_meta.h>
862
+ #include <ATen/ops/logical_not_meta.h>
863
+ #include <ATen/ops/logical_or_meta.h>
864
+ #include <ATen/ops/logical_xor_meta.h>
865
+ #include <ATen/ops/logit_meta.h>
866
+ #include <ATen/ops/logit_backward_meta.h>
867
+ #include <ATen/ops/logspace_meta.h>
868
+ #include <ATen/ops/logsumexp_meta.h>
869
+ #include <ATen/ops/lshift_meta.h>
870
+ #include <ATen/ops/lstm_meta.h>
871
+ #include <ATen/ops/lstm_cell_meta.h>
872
+ #include <ATen/ops/lstm_mps_backward_meta.h>
873
+ #include <ATen/ops/lt_meta.h>
874
+ #include <ATen/ops/lu_solve_meta.h>
875
+ #include <ATen/ops/lu_unpack_meta.h>
876
+ #include <ATen/ops/mH_meta.h>
877
+ #include <ATen/ops/mT_meta.h>
878
+ #include <ATen/ops/margin_ranking_loss_meta.h>
879
+ #include <ATen/ops/masked_fill_meta.h>
880
+ #include <ATen/ops/masked_scatter_meta.h>
881
+ #include <ATen/ops/masked_scatter_backward_meta.h>
882
+ #include <ATen/ops/masked_select_meta.h>
883
+ #include <ATen/ops/masked_select_backward_meta.h>
884
+ #include <ATen/ops/matmul_meta.h>
885
+ #include <ATen/ops/matmul_backward_meta.h>
886
+ #include <ATen/ops/matrix_H_meta.h>
887
+ #include <ATen/ops/matrix_exp_meta.h>
888
+ #include <ATen/ops/matrix_exp_backward_meta.h>
889
+ #include <ATen/ops/matrix_power_meta.h>
890
+ #include <ATen/ops/max_meta.h>
891
+ #include <ATen/ops/max_pool1d_meta.h>
892
+ #include <ATen/ops/max_pool1d_with_indices_meta.h>
893
+ #include <ATen/ops/max_pool2d_meta.h>
894
+ #include <ATen/ops/max_pool2d_backward_meta.h>
895
+ #include <ATen/ops/max_pool2d_with_indices_meta.h>
896
+ #include <ATen/ops/max_pool2d_with_indices_backward_meta.h>
897
+ #include <ATen/ops/max_pool3d_meta.h>
898
+ #include <ATen/ops/max_pool3d_with_indices_meta.h>
899
+ #include <ATen/ops/max_pool3d_with_indices_backward_meta.h>
900
+ #include <ATen/ops/max_unpool2d_meta.h>
901
+ #include <ATen/ops/max_unpool3d_meta.h>
902
+ #include <ATen/ops/maximum_meta.h>
903
+ #include <ATen/ops/mean_meta.h>
904
+ #include <ATen/ops/median_meta.h>
905
+ #include <ATen/ops/meshgrid_meta.h>
906
+ #include <ATen/ops/min_meta.h>
907
+ #include <ATen/ops/minimum_meta.h>
908
+ #include <ATen/ops/miopen_batch_norm_meta.h>
909
+ #include <ATen/ops/miopen_batch_norm_backward_meta.h>
910
+ #include <ATen/ops/miopen_convolution_meta.h>
911
+ #include <ATen/ops/miopen_convolution_add_relu_meta.h>
912
+ #include <ATen/ops/miopen_convolution_relu_meta.h>
913
+ #include <ATen/ops/miopen_convolution_transpose_meta.h>
914
+ #include <ATen/ops/miopen_depthwise_convolution_meta.h>
915
+ #include <ATen/ops/miopen_rnn_meta.h>
916
+ #include <ATen/ops/miopen_rnn_backward_meta.h>
917
+ #include <ATen/ops/mish_meta.h>
918
+ #include <ATen/ops/mish_backward_meta.h>
919
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d_meta.h>
920
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_meta.h>
921
+ #include <ATen/ops/mkldnn_convolution_meta.h>
922
+ #include <ATen/ops/mkldnn_linear_meta.h>
923
+ #include <ATen/ops/mkldnn_linear_backward_meta.h>
924
+ #include <ATen/ops/mkldnn_linear_backward_input_meta.h>
925
+ #include <ATen/ops/mkldnn_linear_backward_weights_meta.h>
926
+ #include <ATen/ops/mkldnn_max_pool2d_meta.h>
927
+ #include <ATen/ops/mkldnn_max_pool2d_backward_meta.h>
928
+ #include <ATen/ops/mkldnn_max_pool3d_meta.h>
929
+ #include <ATen/ops/mkldnn_max_pool3d_backward_meta.h>
930
+ #include <ATen/ops/mkldnn_reorder_conv2d_weight_meta.h>
931
+ #include <ATen/ops/mkldnn_reorder_conv3d_weight_meta.h>
932
+ #include <ATen/ops/mkldnn_rnn_layer_meta.h>
933
+ #include <ATen/ops/mkldnn_rnn_layer_backward_meta.h>
934
+ #include <ATen/ops/mm_meta.h>
935
+ #include <ATen/ops/mode_meta.h>
936
+ #include <ATen/ops/moveaxis_meta.h>
937
+ #include <ATen/ops/movedim_meta.h>
938
+ #include <ATen/ops/mps_convolution_backward_meta.h>
939
+ #include <ATen/ops/mps_convolution_transpose_backward_meta.h>
940
+ #include <ATen/ops/mse_loss_meta.h>
941
+ #include <ATen/ops/mse_loss_backward_meta.h>
942
+ #include <ATen/ops/msort_meta.h>
943
+ #include <ATen/ops/mul_meta.h>
944
+ #include <ATen/ops/multi_margin_loss_meta.h>
945
+ #include <ATen/ops/multi_margin_loss_backward_meta.h>
946
+ #include <ATen/ops/multilabel_margin_loss_meta.h>
947
+ #include <ATen/ops/multilabel_margin_loss_backward_meta.h>
948
+ #include <ATen/ops/multilabel_margin_loss_forward_meta.h>
949
+ #include <ATen/ops/multinomial_meta.h>
950
+ #include <ATen/ops/multiply_meta.h>
951
+ #include <ATen/ops/mv_meta.h>
952
+ #include <ATen/ops/mvlgamma_meta.h>
953
+ #include <ATen/ops/nan_to_num_meta.h>
954
+ #include <ATen/ops/nanmean_meta.h>
955
+ #include <ATen/ops/nanmedian_meta.h>
956
+ #include <ATen/ops/nanquantile_meta.h>
957
+ #include <ATen/ops/nansum_meta.h>
958
+ #include <ATen/ops/narrow_meta.h>
959
+ #include <ATen/ops/narrow_copy_meta.h>
960
+ #include <ATen/ops/native_batch_norm_meta.h>
961
+ #include <ATen/ops/native_batch_norm_backward_meta.h>
962
+ #include <ATen/ops/native_channel_shuffle_meta.h>
963
+ #include <ATen/ops/native_dropout_meta.h>
964
+ #include <ATen/ops/native_dropout_backward_meta.h>
965
+ #include <ATen/ops/native_group_norm_meta.h>
966
+ #include <ATen/ops/native_group_norm_backward_meta.h>
967
+ #include <ATen/ops/native_layer_norm_meta.h>
968
+ #include <ATen/ops/native_layer_norm_backward_meta.h>
969
+ #include <ATen/ops/native_norm_meta.h>
970
+ #include <ATen/ops/ne_meta.h>
971
+ #include <ATen/ops/neg_meta.h>
972
+ #include <ATen/ops/negative_meta.h>
973
+ #include <ATen/ops/nested_to_padded_tensor_meta.h>
974
+ #include <ATen/ops/new_empty_meta.h>
975
+ #include <ATen/ops/new_empty_strided_meta.h>
976
+ #include <ATen/ops/new_full_meta.h>
977
+ #include <ATen/ops/new_ones_meta.h>
978
+ #include <ATen/ops/new_zeros_meta.h>
979
+ #include <ATen/ops/nextafter_meta.h>
980
+ #include <ATen/ops/nll_loss_meta.h>
981
+ #include <ATen/ops/nll_loss2d_meta.h>
982
+ #include <ATen/ops/nll_loss2d_backward_meta.h>
983
+ #include <ATen/ops/nll_loss2d_forward_meta.h>
984
+ #include <ATen/ops/nll_loss_backward_meta.h>
985
+ #include <ATen/ops/nll_loss_forward_meta.h>
986
+ #include <ATen/ops/nll_loss_nd_meta.h>
987
+ #include <ATen/ops/nonzero_meta.h>
988
+ #include <ATen/ops/nonzero_numpy_meta.h>
989
+ #include <ATen/ops/nonzero_static_meta.h>
990
+ #include <ATen/ops/norm_meta.h>
991
+ #include <ATen/ops/norm_except_dim_meta.h>
992
+ #include <ATen/ops/normal_meta.h>
993
+ #include <ATen/ops/not_equal_meta.h>
994
+ #include <ATen/ops/nuclear_norm_meta.h>
995
+ #include <ATen/ops/numpy_T_meta.h>
996
+ #include <ATen/ops/one_hot_meta.h>
997
+ #include <ATen/ops/ones_meta.h>
998
+ #include <ATen/ops/ones_like_meta.h>
999
+ #include <ATen/ops/or_meta.h>
1000
+ #include <ATen/ops/orgqr_meta.h>
1001
+ #include <ATen/ops/ormqr_meta.h>
1002
+ #include <ATen/ops/outer_meta.h>
1003
+ #include <ATen/ops/output_nr_meta.h>
1004
+ #include <ATen/ops/pad_meta.h>
1005
+ #include <ATen/ops/pad_sequence_meta.h>
1006
+ #include <ATen/ops/pairwise_distance_meta.h>
1007
+ #include <ATen/ops/pdist_meta.h>
1008
+ #include <ATen/ops/permute_meta.h>
1009
+ #include <ATen/ops/permute_copy_meta.h>
1010
+ #include <ATen/ops/pin_memory_meta.h>
1011
+ #include <ATen/ops/pinverse_meta.h>
1012
+ #include <ATen/ops/pixel_shuffle_meta.h>
1013
+ #include <ATen/ops/pixel_unshuffle_meta.h>
1014
+ #include <ATen/ops/poisson_meta.h>
1015
+ #include <ATen/ops/poisson_nll_loss_meta.h>
1016
+ #include <ATen/ops/polar_meta.h>
1017
+ #include <ATen/ops/polygamma_meta.h>
1018
+ #include <ATen/ops/positive_meta.h>
1019
+ #include <ATen/ops/pow_meta.h>
1020
+ #include <ATen/ops/prelu_meta.h>
1021
+ #include <ATen/ops/prod_meta.h>
1022
+ #include <ATen/ops/promote_types_meta.h>
1023
+ #include <ATen/ops/put_meta.h>
1024
+ #include <ATen/ops/q_per_channel_axis_meta.h>
1025
+ #include <ATen/ops/q_per_channel_scales_meta.h>
1026
+ #include <ATen/ops/q_per_channel_zero_points_meta.h>
1027
+ #include <ATen/ops/q_scale_meta.h>
1028
+ #include <ATen/ops/q_zero_point_meta.h>
1029
+ #include <ATen/ops/qr_meta.h>
1030
+ #include <ATen/ops/qscheme_meta.h>
1031
+ #include <ATen/ops/quantile_meta.h>
1032
+ #include <ATen/ops/quantize_per_channel_meta.h>
1033
+ #include <ATen/ops/quantize_per_tensor_meta.h>
1034
+ #include <ATen/ops/quantize_per_tensor_dynamic_meta.h>
1035
+ #include <ATen/ops/quantized_batch_norm_meta.h>
1036
+ #include <ATen/ops/quantized_gru_cell_meta.h>
1037
+ #include <ATen/ops/quantized_lstm_cell_meta.h>
1038
+ #include <ATen/ops/quantized_max_pool1d_meta.h>
1039
+ #include <ATen/ops/quantized_max_pool2d_meta.h>
1040
+ #include <ATen/ops/quantized_max_pool3d_meta.h>
1041
+ #include <ATen/ops/quantized_rnn_relu_cell_meta.h>
1042
+ #include <ATen/ops/quantized_rnn_tanh_cell_meta.h>
1043
+ #include <ATen/ops/rad2deg_meta.h>
1044
+ #include <ATen/ops/rand_meta.h>
1045
+ #include <ATen/ops/rand_like_meta.h>
1046
+ #include <ATen/ops/randint_meta.h>
1047
+ #include <ATen/ops/randint_like_meta.h>
1048
+ #include <ATen/ops/randn_meta.h>
1049
+ #include <ATen/ops/randn_like_meta.h>
1050
+ #include <ATen/ops/random_meta.h>
1051
+ #include <ATen/ops/randperm_meta.h>
1052
+ #include <ATen/ops/range_meta.h>
1053
+ #include <ATen/ops/ravel_meta.h>
1054
+ #include <ATen/ops/real_meta.h>
1055
+ #include <ATen/ops/reciprocal_meta.h>
1056
+ #include <ATen/ops/record_stream_meta.h>
1057
+ #include <ATen/ops/refine_names_meta.h>
1058
+ #include <ATen/ops/reflection_pad1d_meta.h>
1059
+ #include <ATen/ops/reflection_pad1d_backward_meta.h>
1060
+ #include <ATen/ops/reflection_pad2d_meta.h>
1061
+ #include <ATen/ops/reflection_pad2d_backward_meta.h>
1062
+ #include <ATen/ops/reflection_pad3d_meta.h>
1063
+ #include <ATen/ops/reflection_pad3d_backward_meta.h>
1064
+ #include <ATen/ops/relu_meta.h>
1065
+ #include <ATen/ops/relu6_meta.h>
1066
+ #include <ATen/ops/remainder_meta.h>
1067
+ #include <ATen/ops/rename_meta.h>
1068
+ #include <ATen/ops/renorm_meta.h>
1069
+ #include <ATen/ops/repeat_meta.h>
1070
+ #include <ATen/ops/repeat_interleave_meta.h>
1071
+ #include <ATen/ops/replication_pad1d_meta.h>
1072
+ #include <ATen/ops/replication_pad1d_backward_meta.h>
1073
+ #include <ATen/ops/replication_pad2d_meta.h>
1074
+ #include <ATen/ops/replication_pad2d_backward_meta.h>
1075
+ #include <ATen/ops/replication_pad3d_meta.h>
1076
+ #include <ATen/ops/replication_pad3d_backward_meta.h>
1077
+ #include <ATen/ops/requires_grad_meta.h>
1078
+ #include <ATen/ops/reshape_meta.h>
1079
+ #include <ATen/ops/reshape_as_meta.h>
1080
+ #include <ATen/ops/resize_meta.h>
1081
+ #include <ATen/ops/resize_as_meta.h>
1082
+ #include <ATen/ops/resize_as_sparse_meta.h>
1083
+ #include <ATen/ops/resolve_conj_meta.h>
1084
+ #include <ATen/ops/resolve_neg_meta.h>
1085
+ #include <ATen/ops/result_type_meta.h>
1086
+ #include <ATen/ops/retain_grad_meta.h>
1087
+ #include <ATen/ops/retains_grad_meta.h>
1088
+ #include <ATen/ops/rms_norm_meta.h>
1089
+ #include <ATen/ops/rnn_relu_meta.h>
1090
+ #include <ATen/ops/rnn_relu_cell_meta.h>
1091
+ #include <ATen/ops/rnn_tanh_meta.h>
1092
+ #include <ATen/ops/rnn_tanh_cell_meta.h>
1093
+ #include <ATen/ops/roll_meta.h>
1094
+ #include <ATen/ops/rot90_meta.h>
1095
+ #include <ATen/ops/round_meta.h>
1096
+ #include <ATen/ops/row_indices_meta.h>
1097
+ #include <ATen/ops/row_indices_copy_meta.h>
1098
+ #include <ATen/ops/row_stack_meta.h>
1099
+ #include <ATen/ops/rrelu_meta.h>
1100
+ #include <ATen/ops/rrelu_with_noise_meta.h>
1101
+ #include <ATen/ops/rrelu_with_noise_backward_meta.h>
1102
+ #include <ATen/ops/rshift_meta.h>
1103
+ #include <ATen/ops/rsqrt_meta.h>
1104
+ #include <ATen/ops/rsub_meta.h>
1105
+ #include <ATen/ops/scalar_tensor_meta.h>
1106
+ #include <ATen/ops/scaled_dot_product_attention_meta.h>
1107
+ #include <ATen/ops/scatter_meta.h>
1108
+ #include <ATen/ops/scatter_add_meta.h>
1109
+ #include <ATen/ops/scatter_reduce_meta.h>
1110
+ #include <ATen/ops/searchsorted_meta.h>
1111
+ #include <ATen/ops/segment_reduce_meta.h>
1112
+ #include <ATen/ops/select_meta.h>
1113
+ #include <ATen/ops/select_backward_meta.h>
1114
+ #include <ATen/ops/select_copy_meta.h>
1115
+ #include <ATen/ops/select_scatter_meta.h>
1116
+ #include <ATen/ops/selu_meta.h>
1117
+ #include <ATen/ops/set_meta.h>
1118
+ #include <ATen/ops/set_data_meta.h>
1119
+ #include <ATen/ops/sgn_meta.h>
1120
+ #include <ATen/ops/sigmoid_meta.h>
1121
+ #include <ATen/ops/sigmoid_backward_meta.h>
1122
+ #include <ATen/ops/sign_meta.h>
1123
+ #include <ATen/ops/signbit_meta.h>
1124
+ #include <ATen/ops/silu_meta.h>
1125
+ #include <ATen/ops/silu_backward_meta.h>
1126
+ #include <ATen/ops/sin_meta.h>
1127
+ #include <ATen/ops/sinc_meta.h>
1128
+ #include <ATen/ops/sinh_meta.h>
1129
+ #include <ATen/ops/size_meta.h>
1130
+ #include <ATen/ops/slice_meta.h>
1131
+ #include <ATen/ops/slice_backward_meta.h>
1132
+ #include <ATen/ops/slice_copy_meta.h>
1133
+ #include <ATen/ops/slice_inverse_meta.h>
1134
+ #include <ATen/ops/slice_scatter_meta.h>
1135
+ #include <ATen/ops/slogdet_meta.h>
1136
+ #include <ATen/ops/slow_conv3d_meta.h>
1137
+ #include <ATen/ops/slow_conv3d_forward_meta.h>
1138
+ #include <ATen/ops/slow_conv_dilated2d_meta.h>
1139
+ #include <ATen/ops/slow_conv_dilated3d_meta.h>
1140
+ #include <ATen/ops/slow_conv_transpose2d_meta.h>
1141
+ #include <ATen/ops/slow_conv_transpose3d_meta.h>
1142
+ #include <ATen/ops/smm_meta.h>
1143
+ #include <ATen/ops/smooth_l1_loss_meta.h>
1144
+ #include <ATen/ops/smooth_l1_loss_backward_meta.h>
1145
+ #include <ATen/ops/soft_margin_loss_meta.h>
1146
+ #include <ATen/ops/soft_margin_loss_backward_meta.h>
1147
+ #include <ATen/ops/softmax_meta.h>
1148
+ #include <ATen/ops/softplus_meta.h>
1149
+ #include <ATen/ops/softplus_backward_meta.h>
1150
+ #include <ATen/ops/softshrink_meta.h>
1151
+ #include <ATen/ops/softshrink_backward_meta.h>
1152
+ #include <ATen/ops/sort_meta.h>
1153
+ #include <ATen/ops/sparse_bsc_tensor_meta.h>
1154
+ #include <ATen/ops/sparse_bsr_tensor_meta.h>
1155
+ #include <ATen/ops/sparse_compressed_tensor_meta.h>
1156
+ #include <ATen/ops/sparse_coo_tensor_meta.h>
1157
+ #include <ATen/ops/sparse_csc_tensor_meta.h>
1158
+ #include <ATen/ops/sparse_csr_tensor_meta.h>
1159
+ #include <ATen/ops/sparse_dim_meta.h>
1160
+ #include <ATen/ops/sparse_mask_meta.h>
1161
+ #include <ATen/ops/sparse_resize_meta.h>
1162
+ #include <ATen/ops/sparse_resize_and_clear_meta.h>
1163
+ #include <ATen/ops/sparse_sampled_addmm_meta.h>
1164
+ #include <ATen/ops/special_airy_ai_meta.h>
1165
+ #include <ATen/ops/special_bessel_j0_meta.h>
1166
+ #include <ATen/ops/special_bessel_j1_meta.h>
1167
+ #include <ATen/ops/special_bessel_y0_meta.h>
1168
+ #include <ATen/ops/special_bessel_y1_meta.h>
1169
+ #include <ATen/ops/special_chebyshev_polynomial_t_meta.h>
1170
+ #include <ATen/ops/special_chebyshev_polynomial_u_meta.h>
1171
+ #include <ATen/ops/special_chebyshev_polynomial_v_meta.h>
1172
+ #include <ATen/ops/special_chebyshev_polynomial_w_meta.h>
1173
+ #include <ATen/ops/special_digamma_meta.h>
1174
+ #include <ATen/ops/special_entr_meta.h>
1175
+ #include <ATen/ops/special_erf_meta.h>
1176
+ #include <ATen/ops/special_erfc_meta.h>
1177
+ #include <ATen/ops/special_erfcx_meta.h>
1178
+ #include <ATen/ops/special_erfinv_meta.h>
1179
+ #include <ATen/ops/special_exp2_meta.h>
1180
+ #include <ATen/ops/special_expit_meta.h>
1181
+ #include <ATen/ops/special_expm1_meta.h>
1182
+ #include <ATen/ops/special_gammainc_meta.h>
1183
+ #include <ATen/ops/special_gammaincc_meta.h>
1184
+ #include <ATen/ops/special_gammaln_meta.h>
1185
+ #include <ATen/ops/special_hermite_polynomial_h_meta.h>
1186
+ #include <ATen/ops/special_hermite_polynomial_he_meta.h>
1187
+ #include <ATen/ops/special_i0_meta.h>
1188
+ #include <ATen/ops/special_i0e_meta.h>
1189
+ #include <ATen/ops/special_i1_meta.h>
1190
+ #include <ATen/ops/special_i1e_meta.h>
1191
+ #include <ATen/ops/special_laguerre_polynomial_l_meta.h>
1192
+ #include <ATen/ops/special_legendre_polynomial_p_meta.h>
1193
+ #include <ATen/ops/special_log1p_meta.h>
1194
+ #include <ATen/ops/special_log_ndtr_meta.h>
1195
+ #include <ATen/ops/special_log_softmax_meta.h>
1196
+ #include <ATen/ops/special_logit_meta.h>
1197
+ #include <ATen/ops/special_logsumexp_meta.h>
1198
+ #include <ATen/ops/special_modified_bessel_i0_meta.h>
1199
+ #include <ATen/ops/special_modified_bessel_i1_meta.h>
1200
+ #include <ATen/ops/special_modified_bessel_k0_meta.h>
1201
+ #include <ATen/ops/special_modified_bessel_k1_meta.h>
1202
+ #include <ATen/ops/special_multigammaln_meta.h>
1203
+ #include <ATen/ops/special_ndtr_meta.h>
1204
+ #include <ATen/ops/special_ndtri_meta.h>
1205
+ #include <ATen/ops/special_polygamma_meta.h>
1206
+ #include <ATen/ops/special_psi_meta.h>
1207
+ #include <ATen/ops/special_round_meta.h>
1208
+ #include <ATen/ops/special_scaled_modified_bessel_k0_meta.h>
1209
+ #include <ATen/ops/special_scaled_modified_bessel_k1_meta.h>
1210
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_t_meta.h>
1211
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_u_meta.h>
1212
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_v_meta.h>
1213
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_w_meta.h>
1214
+ #include <ATen/ops/special_sinc_meta.h>
1215
+ #include <ATen/ops/special_softmax_meta.h>
1216
+ #include <ATen/ops/special_spherical_bessel_j0_meta.h>
1217
+ #include <ATen/ops/special_xlog1py_meta.h>
1218
+ #include <ATen/ops/special_xlogy_meta.h>
1219
+ #include <ATen/ops/special_zeta_meta.h>
1220
+ #include <ATen/ops/split_meta.h>
1221
+ #include <ATen/ops/split_copy_meta.h>
1222
+ #include <ATen/ops/split_with_sizes_meta.h>
1223
+ #include <ATen/ops/split_with_sizes_copy_meta.h>
1224
+ #include <ATen/ops/sqrt_meta.h>
1225
+ #include <ATen/ops/square_meta.h>
1226
+ #include <ATen/ops/squeeze_meta.h>
1227
+ #include <ATen/ops/squeeze_copy_meta.h>
1228
+ #include <ATen/ops/sspaddmm_meta.h>
1229
+ #include <ATen/ops/stack_meta.h>
1230
+ #include <ATen/ops/std_meta.h>
1231
+ #include <ATen/ops/std_mean_meta.h>
1232
+ #include <ATen/ops/stft_meta.h>
1233
+ #include <ATen/ops/stride_meta.h>
1234
+ #include <ATen/ops/sub_meta.h>
1235
+ #include <ATen/ops/subtract_meta.h>
1236
+ #include <ATen/ops/sum_meta.h>
1237
+ #include <ATen/ops/sum_to_size_meta.h>
1238
+ #include <ATen/ops/svd_meta.h>
1239
+ #include <ATen/ops/swapaxes_meta.h>
1240
+ #include <ATen/ops/swapdims_meta.h>
1241
+ #include <ATen/ops/sym_constrain_range_meta.h>
1242
+ #include <ATen/ops/sym_constrain_range_for_size_meta.h>
1243
+ #include <ATen/ops/sym_is_contiguous_meta.h>
1244
+ #include <ATen/ops/sym_numel_meta.h>
1245
+ #include <ATen/ops/sym_size_meta.h>
1246
+ #include <ATen/ops/sym_storage_offset_meta.h>
1247
+ #include <ATen/ops/sym_stride_meta.h>
1248
+ #include <ATen/ops/t_meta.h>
1249
+ #include <ATen/ops/t_copy_meta.h>
1250
+ #include <ATen/ops/take_meta.h>
1251
+ #include <ATen/ops/take_along_dim_meta.h>
1252
+ #include <ATen/ops/tan_meta.h>
1253
+ #include <ATen/ops/tanh_meta.h>
1254
+ #include <ATen/ops/tanh_backward_meta.h>
1255
+ #include <ATen/ops/tensor_split_meta.h>
1256
+ #include <ATen/ops/tensordot_meta.h>
1257
+ #include <ATen/ops/thnn_conv2d_meta.h>
1258
+ #include <ATen/ops/threshold_meta.h>
1259
+ #include <ATen/ops/threshold_backward_meta.h>
1260
+ #include <ATen/ops/tile_meta.h>
1261
+ #include <ATen/ops/to_meta.h>
1262
+ #include <ATen/ops/to_dense_meta.h>
1263
+ #include <ATen/ops/to_dense_backward_meta.h>
1264
+ #include <ATen/ops/to_mkldnn_meta.h>
1265
+ #include <ATen/ops/to_mkldnn_backward_meta.h>
1266
+ #include <ATen/ops/to_padded_tensor_meta.h>
1267
+ #include <ATen/ops/to_sparse_meta.h>
1268
+ #include <ATen/ops/to_sparse_bsc_meta.h>
1269
+ #include <ATen/ops/to_sparse_bsr_meta.h>
1270
+ #include <ATen/ops/to_sparse_csc_meta.h>
1271
+ #include <ATen/ops/to_sparse_csr_meta.h>
1272
+ #include <ATen/ops/topk_meta.h>
1273
+ #include <ATen/ops/trace_meta.h>
1274
+ #include <ATen/ops/trace_backward_meta.h>
1275
+ #include <ATen/ops/transpose_meta.h>
1276
+ #include <ATen/ops/transpose_copy_meta.h>
1277
+ #include <ATen/ops/trapezoid_meta.h>
1278
+ #include <ATen/ops/trapz_meta.h>
1279
+ #include <ATen/ops/triangular_solve_meta.h>
1280
+ #include <ATen/ops/tril_meta.h>
1281
+ #include <ATen/ops/tril_indices_meta.h>
1282
+ #include <ATen/ops/triplet_margin_loss_meta.h>
1283
+ #include <ATen/ops/triu_meta.h>
1284
+ #include <ATen/ops/triu_indices_meta.h>
1285
+ #include <ATen/ops/true_divide_meta.h>
1286
+ #include <ATen/ops/trunc_meta.h>
1287
+ #include <ATen/ops/type_as_meta.h>
1288
+ #include <ATen/ops/unbind_meta.h>
1289
+ #include <ATen/ops/unbind_copy_meta.h>
1290
+ #include <ATen/ops/unflatten_meta.h>
1291
+ #include <ATen/ops/unflatten_dense_tensors_meta.h>
1292
+ #include <ATen/ops/unfold_meta.h>
1293
+ #include <ATen/ops/unfold_backward_meta.h>
1294
+ #include <ATen/ops/unfold_copy_meta.h>
1295
+ #include <ATen/ops/uniform_meta.h>
1296
+ #include <ATen/ops/unique_consecutive_meta.h>
1297
+ #include <ATen/ops/unique_dim_meta.h>
1298
+ #include <ATen/ops/unique_dim_consecutive_meta.h>
1299
+ #include <ATen/ops/unsafe_chunk_meta.h>
1300
+ #include <ATen/ops/unsafe_split_meta.h>
1301
+ #include <ATen/ops/unsafe_split_with_sizes_meta.h>
1302
+ #include <ATen/ops/unsqueeze_meta.h>
1303
+ #include <ATen/ops/unsqueeze_copy_meta.h>
1304
+ #include <ATen/ops/upsample_bicubic2d_meta.h>
1305
+ #include <ATen/ops/upsample_bicubic2d_backward_meta.h>
1306
+ #include <ATen/ops/upsample_bilinear2d_meta.h>
1307
+ #include <ATen/ops/upsample_bilinear2d_backward_meta.h>
1308
+ #include <ATen/ops/upsample_linear1d_meta.h>
1309
+ #include <ATen/ops/upsample_linear1d_backward_meta.h>
1310
+ #include <ATen/ops/upsample_nearest1d_meta.h>
1311
+ #include <ATen/ops/upsample_nearest1d_backward_meta.h>
1312
+ #include <ATen/ops/upsample_nearest2d_meta.h>
1313
+ #include <ATen/ops/upsample_nearest2d_backward_meta.h>
1314
+ #include <ATen/ops/upsample_nearest3d_meta.h>
1315
+ #include <ATen/ops/upsample_nearest3d_backward_meta.h>
1316
+ #include <ATen/ops/upsample_trilinear3d_meta.h>
1317
+ #include <ATen/ops/upsample_trilinear3d_backward_meta.h>
1318
+ #include <ATen/ops/value_selecting_reduction_backward_meta.h>
1319
+ #include <ATen/ops/values_meta.h>
1320
+ #include <ATen/ops/values_copy_meta.h>
1321
+ #include <ATen/ops/vander_meta.h>
1322
+ #include <ATen/ops/var_meta.h>
1323
+ #include <ATen/ops/var_mean_meta.h>
1324
+ #include <ATen/ops/vdot_meta.h>
1325
+ #include <ATen/ops/view_meta.h>
1326
+ #include <ATen/ops/view_as_meta.h>
1327
+ #include <ATen/ops/view_as_complex_meta.h>
1328
+ #include <ATen/ops/view_as_complex_copy_meta.h>
1329
+ #include <ATen/ops/view_as_real_meta.h>
1330
+ #include <ATen/ops/view_as_real_copy_meta.h>
1331
+ #include <ATen/ops/view_copy_meta.h>
1332
+ #include <ATen/ops/vsplit_meta.h>
1333
+ #include <ATen/ops/vstack_meta.h>
1334
+ #include <ATen/ops/where_meta.h>
1335
+ #include <ATen/ops/xlogy_meta.h>
1336
+ #include <ATen/ops/xor_meta.h>
1337
+ #include <ATen/ops/zero_meta.h>
1338
+ #include <ATen/ops/zeros_meta.h>
1339
+ #include <ATen/ops/zeros_like_meta.h>
1340
+
1341
+ namespace at {
1342
+
1343
+ namespace meta {
1344
+
1345
+
1346
+
1347
+ } // namespace meta
1348
+ } // namespace at
1349
+
1350
+ #else
1351
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
1352
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NestedTensorImpl.h ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/MemoryOverlap.h>
4
+ #include <ATen/Tensor.h>
5
+ #include <c10/core/DispatchKey.h>
6
+ #include <c10/core/DispatchKeySet.h>
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/TensorImpl.h>
9
+ #include <c10/util/ArrayRef.h>
10
+ #include <c10/util/Exception.h>
11
+ #include <c10/util/Metaprogramming.h>
12
+ #include <c10/util/irange.h>
13
+
14
+ namespace at::native {
15
+ struct NestedTensorImpl;
16
+ inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
17
+ int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
18
+ at::Tensor construct_nested_strides(const at::Tensor& nested_size);
19
+ at::Tensor construct_offsets(const at::Tensor& nested_size);
20
+
21
+ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
22
+ explicit NestedTensorImpl(
23
+ Storage storage,
24
+ c10::DispatchKeySet key_set,
25
+ const caffe2::TypeMeta data_type,
26
+ at::Tensor nested_sizes,
27
+ at::Tensor nested_strides,
28
+ at::Tensor storage_offsets);
29
+
30
+ explicit NestedTensorImpl(
31
+ const at::Tensor& buffer,
32
+ at::Tensor nested_sizes,
33
+ at::Tensor nested_strides,
34
+ at::Tensor storage_offsets);
35
+ // assume contiguous, `nested_strides` and `offsets`
36
+ // can be inferred from `nested_sizes`
37
+ explicit NestedTensorImpl(
38
+ const at::Tensor& buffer,
39
+ const at::Tensor& nested_sizes);
40
+
41
+ // This constructor is used creating view tensors from nested tensors
42
+ explicit NestedTensorImpl(
43
+ c10::TensorImpl::ImplType impl_type,
44
+ const at::Tensor& base_tensor,
45
+ at::Tensor nested_sizes,
46
+ at::Tensor nested_strides,
47
+ at::Tensor storage_offsets);
48
+
49
+ // TODO: don't expose private implementation details like this; in
50
+ // particular, resizing this tensor will mess up our dim() and
51
+ // callers cannot fix it.
52
+ const Tensor& get_nested_sizes() const {
53
+ return nested_sizes_;
54
+ }
55
+ // TODO: don't expose private implementation details like this
56
+ const Tensor& get_nested_strides() const {
57
+ return nested_strides_;
58
+ }
59
+ const Tensor& get_storage_offsets() const {
60
+ return storage_offsets_;
61
+ }
62
+ // Returns nullopt if the ith dimension is irregular. The ith dimension
63
+ // of a NestedTensor is regular if the unbound tensors match in
64
+ // size at the (i-1)th dimension.
65
+ std::optional<int64_t> opt_size(int64_t d) const;
66
+
67
+ int64_t size(int64_t d) const {
68
+ std::optional<int64_t> optional_size = this->opt_size(d);
69
+ TORCH_CHECK(
70
+ optional_size.has_value(),
71
+ "Given dimension ",
72
+ d,
73
+ " is irregular and does not have a size.");
74
+ return *optional_size;
75
+ }
76
+ /**
77
+ * Return a view of the nested tensor as a 1 dimensional contiguous tensor.
78
+ *
79
+ * The buffer tensor created by this function shares the same storage_impl as
80
+ * the original nested tensor, and therefore can be seen as a view.
81
+ *
82
+ * @return A newly constructed view tensor
83
+ */
84
+ at::Tensor get_buffer() const {
85
+ TORCH_CHECK(
86
+ nested_tensor_impl_is_contiguous(this),
87
+ "NestedTensor must be contiguous to get buffer.");
88
+ return get_unsafe_storage_as_tensor();
89
+ }
90
+ /**
91
+ * If possible use get_buffer() instead. This function returns the storage
92
+ * as a tensor directly, which is not safe to use in general. If using this
93
+ * function, The caller must ensure to account for nested_sizes,
94
+ * nested_strides and storage_offsets.
95
+ *
96
+ * @return A newly constructed view tensor
97
+ */
98
+ at::Tensor get_unsafe_storage_as_tensor() const {
99
+ auto buffer_key_set_ = generate_buffer_key_set();
100
+ const auto buffer_size = get_buffer_size();
101
+ auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
102
+ c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
103
+ buffer_tensor_impl->set_sizes_contiguous(
104
+ c10::makeArrayRef(static_cast<int64_t>(buffer_size)));
105
+ return Tensor(buffer_tensor_impl);
106
+ }
107
+
108
+ size_t get_buffer_size() const {
109
+ return storage_.nbytes() / data_type_.itemsize();
110
+ }
111
+
112
+ protected:
113
+ const char* tensorimpl_type_name() const override;
114
+
115
+ // TODO: numel_custom and is_contiguous_custom can be profitably overridden
116
+ // with real implementations
117
+ int64_t numel_custom() const override;
118
+ c10::SymInt sym_numel_custom() const override;
119
+ c10::SymBool sym_is_contiguous_custom(
120
+ MemoryFormat /*memory_format*/) const override;
121
+ int64_t size_custom(int64_t d) const override {
122
+ return this->size(d);
123
+ }
124
+ c10::SymInt sym_size_custom(int64_t d) const override {
125
+ return c10::SymInt{this->size(d)};
126
+ }
127
+ IntArrayRef sizes_custom() const override;
128
+ c10::SymIntArrayRef sym_sizes_custom() const override;
129
+ IntArrayRef strides_custom() const override;
130
+ c10::SymIntArrayRef sym_strides_custom() const override;
131
+
132
+ // this one is real
133
+ int64_t dim_custom() const override;
134
+
135
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
136
+ const c10::VariableVersion& version_counter,
137
+ bool allow_tensor_metadata_change) const override;
138
+
139
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
140
+ c10::VariableVersion&& version_counter,
141
+ bool allow_tensor_metadata_change) const override;
142
+
143
+ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
144
+ copy_tensor_metadata(
145
+ /*src_impl=*/impl.get(),
146
+ /*dest_impl=*/this,
147
+ /*version_counter=*/version_counter(),
148
+ /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
149
+ }
150
+
151
+ private:
152
+ // Must be called after any changes to our dim() to sync the state
153
+ // to TensorImpl.
154
+ void refresh_dim();
155
+
156
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
157
+ const at::Tensor nested_sizes_, nested_strides_;
158
+ // The starting positions of the underlying tensors in contiguous buffer
159
+ // i.e. the buffer memory offsets to get the underlying tensors
160
+ // The reason to keep this metadata is that, without strong enough constraint
161
+ // it cannot be derived from `nested_sizes_`
162
+ // and `nested_strides_`:
163
+ // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2]
164
+ // this can happen e.g. after slicing a nested tensor
165
+ // 2. when multiple tensors share a same memory
166
+ // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2]
167
+ // Some strong enough constraints are:
168
+ // 1. every underlying tensor is contiguous in memory
169
+ // && nesting in ascending order
170
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
171
+ const at::Tensor storage_offsets_;
172
+ // NOTE: -1 here means the size is missing
173
+ // Optional to allow it to be computed lazily from nested.
174
+ // TODO: maybe we can remove this metadata since
175
+ // we can compute it from `nested_sizes_`
176
+ mutable std::optional<std::vector<int64_t>> opt_sizes_;
177
+
178
+ template <typename VariableVersion>
179
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
180
+ VariableVersion&& version_counter,
181
+ bool allow_tensor_metadata_change) const;
182
+
183
+ /**
184
+ * Generates a non-nested key_set from a nested tensor.
185
+ *
186
+ * For many nested tensor kernel implementations a buffer tensor
187
+ * is generated and redispatched to a non-nested kernel this function
188
+ * generates the key set used by that buffer tensor
189
+ *
190
+ * @return Appropriate key set for non-nested tensor
191
+ */
192
+ inline c10::DispatchKeySet generate_buffer_key_set() const {
193
+ auto buffer_key_set = this->key_set();
194
+ const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset);
195
+ // Remove nested tensor specific keys
196
+ buffer_key_set = buffer_key_set -
197
+ c10::DispatchKeySet{
198
+ c10::DispatchKey::NestedTensor,
199
+ c10::DispatchKey::AutogradNestedTensor};
200
+
201
+ // Add dense tensor specific keys
202
+ buffer_key_set =
203
+ buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
204
+ buffer_key_set = Autograd
205
+ ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
206
+ : buffer_key_set;
207
+
208
+ return buffer_key_set;
209
+ }
210
+ };
211
+
212
+ inline NestedTensorImpl* get_nested_tensor_impl_or_null(
213
+ const at::Tensor& tensor) {
214
+ if (tensor.is_nested()) {
215
+ return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
216
+ }
217
+ return nullptr;
218
+ }
219
+
220
+ inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
221
+ TORCH_CHECK(
222
+ tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
223
+ return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
224
+ }
225
+
226
+ inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
227
+ int64_t ntensors = nt->size(0);
228
+ if (ntensors == 0) {
229
+ return true;
230
+ }
231
+ const Tensor &sizemat = nt->get_nested_sizes(),
232
+ &stridemat = nt->get_nested_strides();
233
+ const int64_t* offsets_ptr =
234
+ nt->get_storage_offsets().const_data_ptr<int64_t>();
235
+ int64_t orig_dim = sizemat.size(1);
236
+ // nesting scalars
237
+ if (orig_dim == 0) {
238
+ // each scalar must be contiguous
239
+ // if there is blank memory between underlying scalars
240
+ for (int64_t i = 0; i < ntensors; i++) {
241
+ if (offsets_ptr[i] != i) {
242
+ return false;
243
+ }
244
+ }
245
+ }
246
+ // nesting tensors
247
+ else {
248
+ // if any underlying tensor is non-contiguous
249
+ const int64_t *sizemat_ptr = sizemat.const_data_ptr<int64_t>(),
250
+ *stridemat_ptr = stridemat.const_data_ptr<int64_t>();
251
+ for (int64_t i = 0; i < ntensors; i++) {
252
+ if (stridemat_ptr[orig_dim - 1] != 1) {
253
+ return false;
254
+ }
255
+ int64_t product = sizemat_ptr[orig_dim - 1];
256
+ for (int64_t j = orig_dim - 2; j >= 0; j--) {
257
+ if (stridemat_ptr[j] != product) {
258
+ return false;
259
+ }
260
+ product *= sizemat_ptr[j];
261
+ }
262
+ sizemat_ptr += orig_dim;
263
+ stridemat_ptr += orig_dim;
264
+ }
265
+ // if there is blank memory between underlying tensors
266
+ if (offsets_ptr[0] != 0) {
267
+ return false;
268
+ }
269
+ sizemat_ptr = sizemat.const_data_ptr<int64_t>();
270
+ stridemat_ptr = stridemat.const_data_ptr<int64_t>();
271
+ for (int64_t i = 1; i < ntensors; i++) {
272
+ if (offsets_ptr[i] !=
273
+ offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
274
+ return false;
275
+ }
276
+ sizemat_ptr += orig_dim;
277
+ stridemat_ptr += orig_dim;
278
+ }
279
+ }
280
+ // everything is fine
281
+ return true;
282
+ }
283
+
284
+ inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
285
+ return get_nested_tensor_impl(tensor)->get_nested_sizes();
286
+ }
287
+
288
+ } // namespace at::native
289
+
290
+ #else
291
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
292
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NumericUtils.h ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #ifdef __HIPCC__
5
+ #include <hip/hip_runtime.h>
6
+ #endif
7
+
8
+ #include <c10/macros/Macros.h>
9
+ #include <c10/util/BFloat16.h>
10
+ #include <c10/util/Float8_e4m3fn.h>
11
+ #include <c10/util/Float8_e4m3fnuz.h>
12
+ #include <c10/util/Float8_e5m2.h>
13
+ #include <c10/util/Float8_e5m2fnuz.h>
14
+ #include <c10/util/Half.h>
15
+ #include <c10/util/complex.h>
16
+
17
+ #include <cmath>
18
+ #include <type_traits>
19
+
20
+ namespace at {
21
+
22
+ // std::isnan isn't performant to use on integral types; it will
23
+ // (uselessly) convert to floating point and then do the test.
24
+ // This function is.
25
+
26
+ template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
27
+ inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
28
+ return false;
29
+ }
30
+
31
+ template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
32
+ inline C10_HOST_DEVICE bool _isnan(T val) {
33
+ #if defined(__CUDACC__) || defined(__HIPCC__)
34
+ return ::isnan(val);
35
+ #else
36
+ return std::isnan(val);
37
+ #endif
38
+ }
39
+
40
+ template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
41
+ inline C10_HOST_DEVICE bool _isnan(T val) {
42
+ return std::isnan(val.real()) || std::isnan(val.imag());
43
+ }
44
+
45
+ template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
46
+ inline C10_HOST_DEVICE bool _isnan(T val) {
47
+ return at::_isnan(static_cast<float>(val));
48
+ }
49
+
50
+ template <
51
+ typename T,
52
+ std::enable_if_t<std::is_same_v<T, at::BFloat16>, int> = 0>
53
+ inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
54
+ return at::_isnan(static_cast<float>(val));
55
+ }
56
+
57
+ inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
58
+ return at::_isnan(static_cast<float>(val));
59
+ }
60
+
61
+ template <
62
+ typename T,
63
+ std::enable_if_t<std::is_same_v<T, at::Float8_e5m2>, int> = 0>
64
+ inline C10_HOST_DEVICE bool _isnan(T val) {
65
+ return val.isnan();
66
+ }
67
+
68
+ template <
69
+ typename T,
70
+ std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fn>, int> = 0>
71
+ inline C10_HOST_DEVICE bool _isnan(T val) {
72
+ return val.isnan();
73
+ }
74
+
75
+ template <
76
+ typename T,
77
+ std::enable_if_t<std::is_same_v<T, at::Float8_e5m2fnuz>, int> = 0>
78
+ inline C10_HOST_DEVICE bool _isnan(T val) {
79
+ return val.isnan();
80
+ }
81
+
82
+ template <
83
+ typename T,
84
+ std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fnuz>, int> = 0>
85
+ inline C10_HOST_DEVICE bool _isnan(T val) {
86
+ return val.isnan();
87
+ }
88
+
89
+ // std::isinf isn't performant to use on integral types; it will
90
+ // (uselessly) convert to floating point and then do the test.
91
+ // This function is.
92
+
93
+ template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
94
+ inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
95
+ return false;
96
+ }
97
+
98
+ template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
99
+ inline C10_HOST_DEVICE bool _isinf(T val) {
100
+ #if defined(__CUDACC__) || defined(__HIPCC__)
101
+ return ::isinf(val);
102
+ #else
103
+ return std::isinf(val);
104
+ #endif
105
+ }
106
+
107
+ inline C10_HOST_DEVICE bool _isinf(at::Half val) {
108
+ return at::_isinf(static_cast<float>(val));
109
+ }
110
+
111
+ inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
112
+ return at::_isinf(static_cast<float>(val));
113
+ }
114
+
115
+ inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
116
+ return val.isinf();
117
+ }
118
+
119
+ inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val [[maybe_unused]]) {
120
+ return false;
121
+ }
122
+
123
+ inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val [[maybe_unused]]) {
124
+ return false;
125
+ }
126
+
127
+ inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val [[maybe_unused]]) {
128
+ return false;
129
+ }
130
+
131
+ template <typename T>
132
+ C10_HOST_DEVICE inline T exp(T x) {
133
+ static_assert(
134
+ !std::is_same_v<T, double>,
135
+ "this template must be used with float or less precise type");
136
+ #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
137
+ // use __expf fast approximation for peak bandwidth
138
+ return __expf(x);
139
+ #else
140
+ return ::exp(x);
141
+ #endif
142
+ }
143
+
144
+ template <>
145
+ C10_HOST_DEVICE inline double exp<double>(double x) {
146
+ return ::exp(x);
147
+ }
148
+
149
+ template <typename T>
150
+ C10_HOST_DEVICE inline T log(T x) {
151
+ static_assert(
152
+ !std::is_same_v<T, double>,
153
+ "this template must be used with float or less precise type");
154
+ #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
155
+ // use __logf fast approximation for peak bandwidth
156
+ return __logf(x);
157
+ #else
158
+ return ::log(x);
159
+ #endif
160
+ }
161
+
162
+ template <>
163
+ C10_HOST_DEVICE inline double log<double>(double x) {
164
+ return ::log(x);
165
+ }
166
+
167
+ template <typename T>
168
+ C10_HOST_DEVICE inline T log1p(T x) {
169
+ static_assert(
170
+ !std::is_same_v<T, double>,
171
+ "this template must be used with float or less precise type");
172
+ #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
173
+ // use __logf fast approximation for peak bandwidth
174
+ // NOTE: There is no __log1pf so unfortunately we lose precision.
175
+ return __logf(1.0f + x);
176
+ #else
177
+ return ::log1p(x);
178
+ #endif
179
+ }
180
+
181
+ template <>
182
+ C10_HOST_DEVICE inline double log1p<double>(double x) {
183
+ return ::log1p(x);
184
+ }
185
+
186
+ template <typename T>
187
+ C10_HOST_DEVICE inline T tan(T x) {
188
+ static_assert(
189
+ !std::is_same_v<T, double>,
190
+ "this template must be used with float or less precise type");
191
+ #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
192
+ // use __tanf fast approximation for peak bandwidth
193
+ return __tanf(x);
194
+ #else
195
+ return ::tan(x);
196
+ #endif
197
+ }
198
+
199
+ template <>
200
+ C10_HOST_DEVICE inline double tan<double>(double x) {
201
+ return ::tan(x);
202
+ }
203
+
204
+ } // namespace at
205
+
206
+ #else
207
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
208
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ParallelOpenMP.h ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <algorithm>
5
+ #include <atomic>
6
+ #include <cstddef>
7
+ #include <exception>
8
+
9
+ #ifdef _OPENMP
10
+ #define INTRA_OP_PARALLEL
11
+
12
+ #include <omp.h>
13
+ #endif
14
+
15
+ #ifdef _OPENMP
16
+ namespace at::internal {
17
+ template <typename F>
18
+ inline void invoke_parallel(
19
+ int64_t begin,
20
+ int64_t end,
21
+ int64_t grain_size,
22
+ const F& f) {
23
+ std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
24
+ std::exception_ptr eptr;
25
+
26
+ #pragma omp parallel
27
+ {
28
+ // choose number of tasks based on grain size and number of threads
29
+ // can't use num_threads clause due to bugs in GOMP's thread pool (See
30
+ // #32008)
31
+ int64_t num_threads = omp_get_num_threads();
32
+ if (grain_size > 0) {
33
+ num_threads = std::min(num_threads, divup((end - begin), grain_size));
34
+ }
35
+
36
+ int64_t tid = omp_get_thread_num();
37
+ int64_t chunk_size = divup((end - begin), num_threads);
38
+ int64_t begin_tid = begin + tid * chunk_size;
39
+ if (begin_tid < end) {
40
+ try {
41
+ internal::ThreadIdGuard tid_guard(tid);
42
+ f(begin_tid, std::min(end, chunk_size + begin_tid));
43
+ } catch (...) {
44
+ if (!err_flag.test_and_set()) {
45
+ eptr = std::current_exception();
46
+ }
47
+ }
48
+ }
49
+ }
50
+ if (eptr) {
51
+ std::rethrow_exception(eptr);
52
+ }
53
+ }
54
+ } // namespace at::internal
55
+ #endif // _OPENMP
56
+
57
+ #else
58
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
59
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/RedispatchFunctions.h ADDED
The diff for this file is too large to render. See raw diff
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/RegistrationDeclarations.h ADDED
The diff for this file is too large to render. See raw diff
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/SDPBackend.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <cstdint>
4
+
5
+ namespace at {
6
+
7
+ constexpr int32_t num_sdp_backends = 5;
8
+ enum class SDPBackend {
9
+ error = -1,
10
+ math = 0,
11
+ flash_attention = 1,
12
+ efficient_attention = 2,
13
+ cudnn_attention = 3,
14
+ overrideable = 4
15
+ };
16
+
17
+ } // namespace at
18
+
19
+ #else
20
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
21
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Scalar.h ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/Scalar.h>
5
+
6
+ #else
7
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
8
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/StorageUtils.h ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/Storage.h>
5
+ #include <c10/core/StorageImpl.h>
6
+ #include <c10/util/intrusive_ptr.h>
7
+
8
+ namespace at {
9
+
10
+ class TensorBase;
11
+
12
+ // Here we define a series of utils to create/manipulate ATen backed
13
+ // c10 storage implementations.
14
+
15
+ /**
16
+ * Create a new shared memory storage impl managed by file descriptor
17
+ *
18
+ * @param size size in bytes
19
+ */
20
+ C10_EXPORT c10::intrusive_ptr<c10::StorageImpl> new_shm_fd_storage(size_t size);
21
+
22
+ /**
23
+ * Copy src to dst
24
+ * Caller must guarantee the validness of the storage objects
25
+ * during the entire copy process, esp. when it's async.
26
+ *
27
+ * This can probably live in c10 namespace later if needed,
28
+ * but for now keep it in at to keep implementation simple.
29
+ *
30
+ * @param dst dst tensor
31
+ * @param src src tensor
32
+ * @param non_blocking (default false) whether this operation blocks caller
33
+ */
34
+ C10_EXPORT void storage_copy(
35
+ c10::Storage& dst,
36
+ const c10::Storage& src,
37
+ bool non_blocking = false);
38
+
39
+ /**
40
+ * In place change the storage to shm based.
41
+ *
42
+ * This is only applicable to CPU tensors not already shared.
43
+ * Otherwise, it's a no op to mirror the THP tensor behavior:
44
+ * https://pytorch.org/docs/stable/generated/torch.Tensor.share_memory_.html
45
+ *
46
+ * @param t a tensor
47
+ */
48
+ C10_EXPORT void share_memory_(TensorBase& t);
49
+
50
+ } // namespace at
51
+
52
+ #else
53
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
54
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/TensorAccessor.h ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/core/TensorAccessor.h>
4
+
5
+ #else
6
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
7
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/SafePyObject.h>
5
+ #include <c10/macros/Macros.h>
6
+ #include <unordered_map>
7
+
8
+ namespace at::impl {
9
+
10
+ struct TORCH_API ThreadLocalPythonObjects {
11
+ static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
12
+ static const std::shared_ptr<SafePyObject>& get(const std::string& key);
13
+ static bool contains(const std::string& key);
14
+
15
+ static const ThreadLocalPythonObjects& get_state();
16
+ static void set_state(ThreadLocalPythonObjects state);
17
+
18
+ private:
19
+ std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
20
+ };
21
+
22
+ } // namespace at::impl
23
+
24
+ #else
25
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
26
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalState.h ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/InferenceMode.h>
5
+ #include <c10/core/impl/LocalDispatchKeySet.h>
6
+ #include <c10/util/Exception.h>
7
+ #include <c10/util/ThreadLocalDebugInfo.h>
8
+
9
+ #include <ATen/FuncTorchTLS.h>
10
+ #include <ATen/PythonTorchFunctionTLS.h>
11
+ #include <ATen/SavedTensorHooks.h>
12
+ #include <ATen/ThreadLocalPythonObjects.h>
13
+ #include <ATen/record_function.h>
14
+ #include <c10/core/impl/PythonDispatcherTLS.h>
15
+ #include <c10/core/impl/TorchDispatchModeTLS.h>
16
+
17
+ namespace at {
18
+
19
+ // Thread local state contains values that are preserved across
20
+ // thread boundaries (e.g. at::launch/JIT fork, autograd).
21
+ // Note at::parallel_for doesn't preserve TLS across thread boundaries.
22
+ class TORCH_API ThreadLocalState {
23
+ public:
24
+ // Saves the thread local variables' values and
25
+ // returns them as a ThreadLocalState
26
+ ThreadLocalState();
27
+
28
+ // set_grad_mode - force the value of the grad mode TLS in
29
+ // the current state object. This is used for example in the
30
+ // autograd engine.
31
+ void set_grad_mode(bool enabled);
32
+
33
+ // set_multithreading_enabled - force the value of the multithreadinmaximum
34
+ // threads TLS in
35
+ // the current state object. This is used for example in the
36
+ // autograd engine.
37
+ void set_multithreading_enabled(bool enabled);
38
+
39
+ // Sets thread local variables in the current thread,
40
+ // according to the thread boundary specified
41
+ static void setThreadLocalState(const ThreadLocalState& state);
42
+
43
+ private:
44
+ c10::impl::LocalDispatchKeySet dispatch_key_;
45
+
46
+ // ThreadLocalDebugInfo does not change after being created
47
+ // with DebugInfoGuard
48
+ std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
49
+
50
+ // RecordFunction TLS
51
+ RecordFunctionTLS rf_tls_;
52
+
53
+ // TLS for out-of-tree functorch
54
+ // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
55
+ // pointer (spoiler alert: it's due to the indirection)
56
+ // This needs to be a shared_ptr instead of a unique_ptr because
57
+ // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
58
+ // consider adding an explicit copy constructor for ThreadLocalState in the
59
+ // future but I didn't want to add one just for this.
60
+ std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
61
+
62
+ // TLS for AutogradModes
63
+ AutogradState autograd_tls_;
64
+
65
+ // TLS for enable_torch_dispatch_mode
66
+ c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
67
+
68
+ // TLS for enable_python_dispatcher
69
+ c10::impl::PyInterpreter* python_dispatcher_state_;
70
+
71
+ // TLS for __torch_function__ (mode and disable_torch_function)
72
+ at::impl::PythonTorchFunctionTLS python_torch_function_state_;
73
+
74
+ // TLS for saved tensors default hooks
75
+ at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
76
+
77
+ bool functionalization_reapply_views_state_;
78
+
79
+ bool dtensor_allow_implicit_replication_;
80
+
81
+ // TLS for arbitrary python objects that is registered via hooks
82
+ at::impl::ThreadLocalPythonObjects saved_objects_;
83
+
84
+ #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \
85
+ !defined(BUILD_LITE_INTERPRETER)
86
+ // TLS for autocast dtypes
87
+ std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
88
+ autocast_dtypes_{};
89
+ #endif
90
+
91
+ friend class ThreadLocalStateGuard;
92
+ };
93
+
94
+ // Guard to set and reset the thread local state
95
+ class TORCH_API ThreadLocalStateGuard {
96
+ public:
97
+ explicit ThreadLocalStateGuard(const ThreadLocalState& state)
98
+ : prev_state_(ThreadLocalState()) {
99
+ // set the given state across the thread boundary
100
+ ThreadLocalState::setThreadLocalState(state);
101
+ }
102
+ ThreadLocalStateGuard(ThreadLocalStateGuard&& other) = delete;
103
+ ThreadLocalStateGuard(const ThreadLocalStateGuard&) = delete;
104
+ ThreadLocalStateGuard& operator=(const ThreadLocalStateGuard&) = delete;
105
+ ThreadLocalStateGuard& operator=(ThreadLocalStateGuard&&) = delete;
106
+
107
+ ~ThreadLocalStateGuard() {
108
+ // restore previously set variables
109
+ ThreadLocalState::setThreadLocalState(prev_state_);
110
+ }
111
+
112
+ private:
113
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
114
+ const ThreadLocalState prev_state_;
115
+ };
116
+
117
+ template <typename T>
118
+ auto wrapPropagateTLSState(T callback) {
119
+ return [tls_state = ThreadLocalState(),
120
+ callback = std::move(callback)](auto&&... args) {
121
+ ThreadLocalStateGuard g(tls_state);
122
+ // Propagate value returned by callback().
123
+ return callback(std::forward<decltype(args)>(args)...);
124
+ };
125
+ }
126
+
127
+ } // namespace at
128
+
129
+ #else
130
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
131
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Utils.h ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/EmptyTensor.h>
5
+ #include <ATen/Formatting.h>
6
+ #include <ATen/core/ATenGeneral.h>
7
+ #include <ATen/core/Generator.h>
8
+ #include <c10/core/ScalarType.h>
9
+ #include <c10/core/StorageImpl.h>
10
+ #include <c10/core/UndefinedTensorImpl.h>
11
+ #include <c10/util/ArrayRef.h>
12
+ #include <c10/util/Exception.h>
13
+ #include <c10/util/accumulate.h>
14
+ #include <c10/util/irange.h>
15
+
16
+ #include <algorithm>
17
+
18
+ #define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
19
+ TypeName(const TypeName&) = delete; \
20
+ void operator=(const TypeName&) = delete
21
+
22
+ namespace at {
23
+
24
+ TORCH_API int _crash_if_asan(int /*arg*/);
25
+
26
+ // Converts a TensorList (i.e. ArrayRef<Tensor> to vector of TensorImpl*)
27
+ // NB: This is ONLY used by legacy TH bindings, and ONLY used by cat.
28
+ // Once cat is ported entirely to ATen this can be deleted!
29
+ inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap(
30
+ ArrayRef<Tensor> tensors,
31
+ const char* name,
32
+ int pos,
33
+ c10::DeviceType device_type,
34
+ ScalarType scalar_type) {
35
+ std::vector<TensorImpl*> unwrapped;
36
+ unwrapped.reserve(tensors.size());
37
+ for (const auto i : c10::irange(tensors.size())) {
38
+ const auto& expr = tensors[i];
39
+ if (expr.layout() != Layout::Strided) {
40
+ TORCH_CHECK(
41
+ false,
42
+ "Expected dense tensor but got ",
43
+ expr.layout(),
44
+ " for sequence element ",
45
+ i,
46
+ " in sequence argument at position #",
47
+ pos,
48
+ " '",
49
+ name,
50
+ "'");
51
+ }
52
+ if (expr.device().type() != device_type) {
53
+ TORCH_CHECK(
54
+ false,
55
+ "Expected object of device type ",
56
+ device_type,
57
+ " but got device type ",
58
+ expr.device().type(),
59
+ " for sequence element ",
60
+ i,
61
+ " in sequence argument at position #",
62
+ pos,
63
+ " '",
64
+ name,
65
+ "'");
66
+ }
67
+ if (expr.scalar_type() != scalar_type) {
68
+ TORCH_CHECK(
69
+ false,
70
+ "Expected object of scalar type ",
71
+ scalar_type,
72
+ " but got scalar type ",
73
+ expr.scalar_type(),
74
+ " for sequence element ",
75
+ i,
76
+ " in sequence argument at position #",
77
+ pos,
78
+ " '",
79
+ name,
80
+ "'");
81
+ }
82
+ unwrapped.emplace_back(expr.unsafeGetTensorImpl());
83
+ }
84
+ return unwrapped;
85
+ }
86
+
87
+ template <size_t N>
88
+ std::array<int64_t, N> check_intlist(
89
+ ArrayRef<int64_t> list,
90
+ const char* name,
91
+ int pos) {
92
+ if (list.empty()) {
93
+ // TODO: is this necessary? We used to treat nullptr-vs-not in IntList
94
+ // differently with strides as a way of faking optional.
95
+ list = {};
96
+ }
97
+ auto res = std::array<int64_t, N>();
98
+ if (list.size() == 1 && N > 1) {
99
+ res.fill(list[0]);
100
+ return res;
101
+ }
102
+ if (list.size() != N) {
103
+ TORCH_CHECK(
104
+ false,
105
+ "Expected a list of ",
106
+ N,
107
+ " ints but got ",
108
+ list.size(),
109
+ " for argument #",
110
+ pos,
111
+ " '",
112
+ name,
113
+ "'");
114
+ }
115
+ std::copy_n(list.begin(), N, res.begin());
116
+ return res;
117
+ }
118
+
119
+ using at::detail::check_size_nonnegative;
120
+
121
+ namespace detail {
122
+
123
+ template <typename T>
124
+ TORCH_API Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options);
125
+
126
+ template <typename T>
127
+ TORCH_API Tensor
128
+ tensor_backend(ArrayRef<T> values, const TensorOptions& options);
129
+
130
+ template <typename T>
131
+ TORCH_API Tensor
132
+ tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options);
133
+
134
+ template <typename T>
135
+ TORCH_API Tensor
136
+ tensor_complex_backend(ArrayRef<T> values, const TensorOptions& options);
137
+ } // namespace detail
138
+
139
+ } // namespace at
140
+
141
+ #else
142
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
143
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpp_custom_type_hack.h ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
3
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
4
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
5
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
6
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
7
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
8
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
9
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
10
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
11
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
12
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
13
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
14
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
15
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
16
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
17
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
18
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
19
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
20
+
21
+ // YOU ARE IN THE WRONG PLACE! TURN BACK NOW!
22
+
23
+ // This code was a temporary hack to enable embedding arbitrary C++ structures
24
+ // into Tensors. THIS IS UNSAFE AND IS NOT SUPPORTED. IF YOU USE THIS CODE,
25
+ // IT __WILL__ BREAK.
26
+
27
+ // This code has been superseded by custom classes:
28
+ // https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html
29
+
30
+ // Please use custom classes and **DO NOT ADD MORE CALLSITES TO THINGS DEFINED
31
+ // IN THIS FILE**.
32
+
33
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
34
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
35
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
36
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
37
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
38
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
39
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
40
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
41
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
42
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
43
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
44
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
45
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
46
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
47
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
48
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
49
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
50
+ // STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
51
+
52
+ #include <ATen/TracerMode.h>
53
+ #include <ATen/core/Tensor.h>
54
+
55
+ #ifndef AT_PER_OPERATOR_HEADERS
56
+ #include <ATen/Functions.h>
57
+ #else
58
+ #include <ATen/ops/empty.h>
59
+ #endif
60
+
61
+ namespace at::cpp_custom_type_hack {
62
+
63
+ template <typename T>
64
+ [[deprecated(
65
+ "Use custom classes instead: "
66
+ "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] bool
67
+ isa(const Tensor& packed) {
68
+ return (packed.scalar_type() == kByte) &&
69
+ (packed.storage().data_ptr().get_deleter() ==
70
+ caffe2::TypeMeta::Make<T>().deleteFn());
71
+ }
72
+
73
+ template <typename T>
74
+ [[deprecated(
75
+ "Use custom classes instead: "
76
+ "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] T&
77
+ cast(const Tensor& packed) {
78
+ TORCH_CHECK(
79
+ packed.scalar_type() == kByte, "Expected temporary cpp type wrapper");
80
+ TORCH_CHECK(
81
+ packed.storage().data_ptr().get_deleter() ==
82
+ caffe2::TypeMeta::Make<T>().deleteFn(),
83
+ "Expected temporary cpp type wrapper of type ",
84
+ caffe2::TypeMeta::TypeName<T>());
85
+ return *reinterpret_cast<T*>(packed.storage().data_ptr().get());
86
+ }
87
+
88
+ template <typename T>
89
+ [[deprecated(
90
+ "Use custom classes instead: "
91
+ "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] Tensor
92
+ create(std::unique_ptr<T> ptr, TensorOptions options) {
93
+ // None of this should trace, so turn off Tracer dispatching
94
+ at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
95
+ at::tracer::impl::NoTracerDispatchMode tracer_guard;
96
+
97
+ // We store this instance away in a Tensor and register a deleter function
98
+ // so that we do not leak memory. On the other side, we pull out the storage's
99
+ // data_ptr and get the right typed pointer.
100
+ void* raw_ptr = ptr.release();
101
+ at::DataPtr at_ptr(
102
+ raw_ptr, raw_ptr, caffe2::TypeMeta::Make<T>().deleteFn(), at::kCPU);
103
+
104
+ // size doesn't really matter, but we can align it to the actual size
105
+ // returning variables because one likely want to use this hack from python
106
+ auto retval = at::empty({sizeof(T)}, options.device(kCPU).dtype(at::kByte));
107
+ retval.storage().set_data_ptr_noswap(std::move(at_ptr));
108
+ return retval;
109
+ }
110
+
111
+ } // namespace at::cpp_custom_type_hack
112
+
113
+ #else
114
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
115
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/THC/THCAtomics.cuh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ // TODO: Remove once torchvision has been updated to use the ATen header
4
+ #include <ATen/cuda/Atomic.cuh>
5
+
6
+ #else
7
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
8
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/THC/THCDeviceUtils.cuh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ // TODO: Remove this header
4
+ #include <ATen/cuda/DeviceUtils.cuh>
5
+
6
+ #else
7
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
8
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/ConvUtils.h ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ /*
3
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ * All rights reserved.
5
+ *
6
+ * This source code is licensed under the BSD-style license found in the
7
+ * LICENSE file in the root directory of this source tree.
8
+ */
9
+
10
+ #pragma once
11
+
12
+ #include <array>
13
+ #include <stdexcept>
14
+ #include <string>
15
+ #include <type_traits>
16
+
17
+ namespace fbgemm {
18
+
19
+ template <int N, int... Vals>
20
+ constexpr std::enable_if_t<N == sizeof...(Vals), std::array<int, N>>
21
+ array_of_ones() {
22
+ return std::array<int, N>{{Vals...}};
23
+ }
24
+
25
+ template <int N, int... Vals>
26
+ constexpr std::enable_if_t<N != sizeof...(Vals), std::array<int, N>>
27
+ array_of_ones() {
28
+ return array_of_ones<N, Vals..., 1>();
29
+ }
30
+
31
+ template <int N, int... Vals>
32
+ constexpr std::enable_if_t<N == sizeof...(Vals), std::array<int, N>>
33
+ array_of_zeroes() {
34
+ return std::array<int, N>{{Vals...}};
35
+ }
36
+
37
+ template <int N, int... Vals>
38
+ constexpr std::enable_if_t<N != sizeof...(Vals), std::array<int, N>>
39
+ array_of_zeroes() {
40
+ return array_of_zeroes<N, Vals..., 0>();
41
+ }
42
+
43
+ /**
44
+ * @brief A struct to conveniently store all convolution parameters.
45
+ */
46
+ template <int SPATIAL_DIM = 2>
47
+ struct conv_param_t {
48
+ int MB; ///< Mini Batch size
49
+ int IC; ///< Number of Input Channels
50
+ int OC; ///< Number of Output Channels
51
+ std::array<int, SPATIAL_DIM> IN_DIM; ///< Input Image Dimension
52
+ int G; ///< Number of Groups
53
+ std::array<int, SPATIAL_DIM> K; ///< Filter (Kernel) dimensions
54
+ std::array<int, SPATIAL_DIM> stride; //< Strides
55
+ std::array<int, SPATIAL_DIM * 2>
56
+ pad; //< Padding (first SPATIAL_DIM is for prev/top/left padding, second
57
+ // SPATIAL_DIM is for next/bottom/right padding)
58
+ std::array<int, SPATIAL_DIM> dilation; //< Kernel dilation
59
+
60
+ // The following are derived parameters
61
+ std::array<int, SPATIAL_DIM> OUT_DIM; //< Output Image Dimension
62
+ std::array<int, SPATIAL_DIM> IN_DIMP; //< Input Image Dimension Padded
63
+
64
+ // The following is for tranposed convolution
65
+ std::array<int, SPATIAL_DIM>
66
+ output_pad; //< Padding (next/bottom/right padding in output buffer)
67
+ bool transposed;
68
+
69
+ /**
70
+ * @brief Constructor for initializing the convolution parameters.
71
+ */
72
+ conv_param_t(
73
+ int mb,
74
+ int ic,
75
+ int oc,
76
+ std::array<int, SPATIAL_DIM> in_dim,
77
+ int g,
78
+ std::array<int, SPATIAL_DIM> k,
79
+ std::array<int, SPATIAL_DIM> strd,
80
+ std::array<int, SPATIAL_DIM * 2> pd,
81
+ std::array<int, SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>(),
82
+ std::array<int, SPATIAL_DIM> otpt_pd = array_of_zeroes<SPATIAL_DIM>(),
83
+ bool transposed = false)
84
+ : MB(mb),
85
+ IC(ic),
86
+ OC(oc),
87
+ IN_DIM(in_dim),
88
+ G(g),
89
+ K(k),
90
+ stride(strd),
91
+ pad(pd),
92
+ dilation(dilations),
93
+ output_pad(otpt_pd),
94
+ transposed(transposed) {
95
+ if (ic % g != 0) {
96
+ throw std::runtime_error(
97
+ "groups = " + std::to_string(g) +
98
+ " does not divide number of input channels = " + std::to_string(ic));
99
+ }
100
+ if (oc % g != 0) {
101
+ throw std::runtime_error(
102
+ "groups = " + std::to_string(g) +
103
+ " does not divide number of output channels = " + std::to_string(oc));
104
+ }
105
+
106
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
107
+ if (transposed) {
108
+ this->IN_DIMP[d] = this->IN_DIM[d] +
109
+ (this->dilation[d] * (this->K[d] - 1) - this->pad[d]) +
110
+ (this->dilation[d] * (this->K[d] - 1) - this->pad[SPATIAL_DIM + d]);
111
+ this->OUT_DIM[d] = (this->IN_DIM[d] - 1) * this->stride[d] -
112
+ this->pad[d] - this->pad[SPATIAL_DIM + d] +
113
+ this->dilation[d] * (this->K[d] - 1) + output_pad[d] + 1;
114
+ } else {
115
+ IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d];
116
+ OUT_DIM[d] =
117
+ (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1;
118
+ }
119
+ }
120
+ }
121
+
122
+ /**
123
+ * @brief Helper function to get convolution parameters as string.
124
+ */
125
+ std::string toString() const {
126
+ std::string dim_string[3] = {"T", "H", "W"};
127
+
128
+ std::string out;
129
+ out += "MB:" + std::to_string(MB) + ", ";
130
+ out += "IC:" + std::to_string(IC) + ", ";
131
+ out += "OC:" + std::to_string(OC) + ", ";
132
+ if constexpr (SPATIAL_DIM <= 3) {
133
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
134
+ out += "I" + dim_string[3 - SPATIAL_DIM + d] + ":" +
135
+ std::to_string(IN_DIM[d]) + ", ";
136
+ }
137
+ } else {
138
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
139
+ out += "I" + std::to_string(d) + ":" + std::to_string(IN_DIM[d]) + ", ";
140
+ }
141
+ }
142
+ out += "G:" + std::to_string(G) + ", ";
143
+ if constexpr (SPATIAL_DIM <= 3) {
144
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
145
+ out += "K" + dim_string[3 - SPATIAL_DIM + d] + ":" +
146
+ std::to_string(K[d]) + ", ";
147
+ }
148
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
149
+ out += "stride_" + dim_string[3 - SPATIAL_DIM + d] + ":" +
150
+ std::to_string(stride[d]) + ", ";
151
+ }
152
+ for (int d = 0; d < SPATIAL_DIM * 2; ++d) {
153
+ out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" +
154
+ std::to_string(pad[d]) + ", ";
155
+ }
156
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
157
+ out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" +
158
+ std::to_string(dilation[d]);
159
+ if (d < SPATIAL_DIM - 1) {
160
+ out += ", ";
161
+ }
162
+ }
163
+ } else {
164
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
165
+ out += "K" + std::to_string(d) + ":" + std::to_string(K[d]) + ", ";
166
+ }
167
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
168
+ out += "stride_" + std::to_string(d) + ":" + std::to_string(stride[d]) +
169
+ ", ";
170
+ }
171
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
172
+ out += "pad_" + std::to_string(d) + ":" + std::to_string(pad[d]);
173
+ if (d < SPATIAL_DIM * 2 - 1) {
174
+ out += ", ";
175
+ }
176
+ }
177
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
178
+ out += "dilation_" + std::to_string(d) + ":" +
179
+ std::to_string(dilation[d]) + ", ";
180
+ }
181
+ }
182
+ if (transposed) {
183
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
184
+ out += "output_padding_" + std::to_string(d) + ":" +
185
+ std::to_string(output_pad[d]) + ", ";
186
+ }
187
+ }
188
+ return out;
189
+ }
190
+ };
191
+ } // namespace fbgemm
192
+
193
+ #else
194
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
195
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/Fbgemm.h ADDED
@@ -0,0 +1,1515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ /*
3
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ * All rights reserved.
5
+ *
6
+ * This source code is licensed under the BSD-style license found in the
7
+ * LICENSE file in the root directory of this source tree.
8
+ */
9
+
10
+ #pragma once
11
+
12
+ /**
13
+ * Top level include file for FBGEMM.
14
+ */
15
+ #include <cassert>
16
+ #include <memory>
17
+ #include "./ConvUtils.h" // @manual
18
+ #include "./FbgemmBuild.h" // @manual
19
+ #include "./FbgemmEmbedding.h" // @manual
20
+ #include "./FbgemmI8DepthwiseAvx2.h" // @manual
21
+ #include "./FbgemmI8DirectconvAvx2.h" // @manual
22
+ #include "./FbgemmI8Spmdm.h" // @manual
23
+ #include "./FloatConversion.h" // @manual
24
+ #include "./QuantUtilsAvx2.h" // @manual
25
+ #include "./Types.h" // @manual
26
+ #include "./Utils.h" // @manual
27
+
28
+ // Turning on this option will print out time breakdown of each stage (e.g.,
29
+ // input packing, the main GEMM kernel, each output processing pipeline).
30
+ // Please note that currently this option won't report accurate timing if
31
+ // multiple threads are used.
32
+ // #define FBGEMM_MEASURE_TIME_BREAKDOWN
33
+
34
+ #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
35
+ #include <chrono>
36
+ #include <iostream>
37
+ extern double packing_time;
38
+ extern double computing_time;
39
+ extern double kernel_time;
40
+ extern double postprocessing_time;
41
+ extern double run_time;
42
+ #endif
43
+
44
+ namespace fbgemm {
45
+
46
+ /**
47
+ * @brief Templatized struct for packing parameters for A and B matrices.
48
+ *
49
+ * @tparam T input type
50
+ * @tparam accT the type used for accumulation
51
+ * @tparam instSet anyarch/avx2/avx512
52
+ * @tparam int8Type an auxiliary template parameter to specialize for 8-bit
53
+ * input types.
54
+ */
55
+ template <
56
+ typename T,
57
+ typename accT,
58
+ inst_set_t instSet,
59
+ typename int8Type = void>
60
+ struct PackingTraits;
61
+
62
+ // type specialized implementation in an include file
63
+ #include "./PackingTraits-inl.h" // @manual
64
+
65
+ /**
66
+ * @brief Base class for packing matrices for higher GEMM performance.
67
+ *
68
+ * Matrix is tiled into blockRows() * blockCols() blocks.
69
+ * Each block is with size blockRowSize() * blockColSize().
70
+ * This class is designed using CRTP
71
+ * (https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern)
72
+ *
73
+ * @tparam PT actual packing type, e.g., PackAWithRowOffset
74
+ */
75
+ template <typename PT, typename inpType, typename accType = std::int32_t>
76
+ class PackMatrix {
77
+ public:
78
+ PackMatrix() = delete; // no default constructor
79
+ PackMatrix(const PackMatrix&) = delete; // no copy
80
+ PackMatrix& operator=(const PackMatrix&) = delete; // no copy
81
+ PackMatrix(PackMatrix&&) = delete; // no move
82
+ PackMatrix& operator=(PackMatrix&& rhs) noexcept = delete; // no move
83
+
84
+ /**
85
+ * @param rows total number of rows in the matrix
86
+ * (packed rows can be less than rows).
87
+ * @param cols total number of columns in the matrix
88
+ * @param pmat A buffer to contain the packed matrix.
89
+ * If nullptr, a buffer owned by PackMatrix will be allocated
90
+ * internally to contain the packed matrix.
91
+ * For non-constant matrices like activation matrices, the client
92
+ * code may want to pass a pre-allocated pmat to avoid the
93
+ * overhead of internal memory allocation everytime a PackMatrix
94
+ * is constructed. The client code can query how big patm should
95
+ * be with packedBufferSize function.
96
+ * @param groups when groups > 1, we compute groups number of GEMMs each
97
+ * multiplies A.rows by A.cols/A.groups matrix with
98
+ * B.rows/B.groups by B.cols matrix (in conventional BLAS
99
+ * terminology, this is a batched GEMM but we use the name group
100
+ * to follow deep learning terminology). The result matrix has
101
+ * dimension A.rows by B.cols*B.groups .
102
+ * A.groups must be same as B.groups, A.groups must divide
103
+ * A.cols, and B.groups must divide B.rows and C.cols.
104
+ */
105
+ PackMatrix(
106
+ std::int32_t rows,
107
+ std::int32_t cols,
108
+ inpType* pmat,
109
+ int groups = 1,
110
+ const BlockingFactors* params = nullptr);
111
+
112
+ /**
113
+ * @return true usually when the matrix is constant matrix (e.g., weight
114
+ * matrices) that can be prepacked
115
+ */
116
+ bool isPrePacked() const {
117
+ return static_cast<const PT*>(this)->isPrePacked();
118
+ }
119
+
120
+ /**
121
+ * @return true if this is the first input matrix in GEMM (i.e., A in C = A *
122
+ * B)
123
+ */
124
+ static bool isA() {
125
+ return PT::isA();
126
+ }
127
+
128
+ /**
129
+ * @brief The size of the buffer used for packing (The size is in number of
130
+ * elements).
131
+ *
132
+ * rows and cols are only used for fully packing, i.e., for B matrix. The
133
+ * client code can use this function to query how big the buffer used for
134
+ * packing should be.
135
+ */
136
+ static int packedBufferSize(
137
+ int rows = 0,
138
+ int cols = 0,
139
+ const BlockingFactors* params = nullptr);
140
+
141
+ FBGEMM_PUSH_WARNING_AND_DISABLE("-Wpragmas")
142
+ FBGEMM_PUSH_WARNING_AND_DISABLE("-Winfinite-recursion")
143
+ /**
144
+ * @return Pointer to a buffer containing row offset results. Some packing
145
+ * objects fuse row offset computation for later requantization step.
146
+ */
147
+ std::int32_t* getRowOffsetBuffer() const {
148
+ return static_cast<const PT*>(this)->getRowOffsetBuffer();
149
+ }
150
+ /**
151
+ * @brief When k loop is also tiled/blocked, this function is used to check if
152
+ * have executed computations for the last k block so that we can perform
153
+ * post-GEMM operations.
154
+ */
155
+ bool isThisLastKBlock(int block_id) const {
156
+ return static_cast<const PT*>(this)->isThisLastKBlock(block_id);
157
+ }
158
+ FBGEMM_POP_WARNING
159
+ FBGEMM_POP_WARNING
160
+
161
+ /**
162
+ * @brief Actual packing of a block of the source matrix in pmat buffer.
163
+ */
164
+ void pack(const block_type_t& block) {
165
+ #if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
166
+ static_cast<PT*>(this)->pack(block);
167
+ #else
168
+ throw std::runtime_error("PackMatrix::pack() not implemented for aarch64");
169
+ #endif // __aarch64__
170
+ }
171
+
172
+ std::int32_t numRows() const {
173
+ return nrows_;
174
+ }
175
+
176
+ std::int32_t numCols() const {
177
+ return ncols_;
178
+ }
179
+
180
+ /**
181
+ * @return The number of rows in each block
182
+ */
183
+ std::int32_t blockRowSize() const {
184
+ return brow_;
185
+ }
186
+
187
+ /**
188
+ * @return The number of columns in each block
189
+ */
190
+ std::int32_t blockColSize() const {
191
+ return bcol_;
192
+ }
193
+
194
+ /**
195
+ * @return The number of blocks along rows
196
+ */
197
+ std::int32_t blockRows() const {
198
+ return nbrow_;
199
+ }
200
+
201
+ /**
202
+ * @return The number of blocks along columns
203
+ */
204
+ std::int32_t blockCols() const {
205
+ return nbcol_;
206
+ }
207
+
208
+ /**
209
+ * @return The number of the rows in the currently packed block of a matrix.
210
+ * For pre-packed (i.e., fully-packed), it's equal to the total number
211
+ * of rows.
212
+ */
213
+ std::int32_t numPackedRows() const {
214
+ return packedBlock_.row_size;
215
+ }
216
+
217
+ /**
218
+ * @return The number of columns in the currently packed block of a matrix.
219
+ * For pre-packed (i.e., fully-packed), it's equal to the number of
220
+ * columns.
221
+ */
222
+ std::int32_t numPackedCols() const {
223
+ return packedBlock_.col_size;
224
+ }
225
+
226
+ /**
227
+ * @return The first row of the block we're working on.
228
+ */
229
+ std::int32_t packedRowStart() const {
230
+ return packedBlock_.row_start;
231
+ }
232
+
233
+ /**
234
+ * @return The first column of the block we're working on.
235
+ */
236
+ std::int32_t packedColStart() const {
237
+ return packedBlock_.col_start;
238
+ }
239
+
240
+ /**
241
+ * @return The beginning of (rowBlockNum, colBlockNum)th block
242
+ */
243
+ inpType* getBuf(std::int32_t rowBlockNum = 0, std::int32_t colBlockNum = 0) {
244
+ return buf_ + blockRowSize() * blockColSize() * rowBlockNum +
245
+ blockRowSize() * blockColSize() * blockCols() * colBlockNum;
246
+ }
247
+
248
+ /**
249
+ * @brief Print the packed block.
250
+ */
251
+ void printPackedMatrix(const std::string& name) {
252
+ static_cast<PT*>(this)->printPackedMatrix(name);
253
+ }
254
+
255
+ /**
256
+ * @return The number of rows in the last row block.
257
+ */
258
+ std::int32_t lastBrow() const {
259
+ return last_brow_;
260
+ }
261
+
262
+ /**
263
+ * @return The number of columns in the last column block.
264
+ */
265
+ std::int32_t lastBcol() const {
266
+ return last_bcol_;
267
+ }
268
+
269
+ int numGroups() const {
270
+ return G_;
271
+ }
272
+
273
+ /**
274
+ * @return True if the last column block has fewer columns than the block
275
+ * size.
276
+ */
277
+ bool isThereColRemainder() const {
278
+ return last_bcol_ != blockColSize();
279
+ }
280
+
281
+ virtual ~PackMatrix() {
282
+ if (bufAllocatedHere_) {
283
+ fbgemmAlignedFree(buf_);
284
+ }
285
+ }
286
+
287
+ protected:
288
+ /**
289
+ * Set which block we're packing
290
+ */
291
+ void packedBlock(const block_type_t& block) {
292
+ packedBlock_ = block;
293
+ nbrow_ = (numPackedRows() + blockRowSize() - 1) / blockRowSize();
294
+ nbcol_ = (numPackedCols() + blockColSize() - 1) / blockColSize();
295
+
296
+ last_brow_ = ((numPackedRows() % blockRowSize()) == 0)
297
+ ? blockRowSize()
298
+ : (numPackedRows() % blockRowSize());
299
+ last_bcol_ = ((numPackedCols() % blockColSize()) == 0)
300
+ ? blockColSize()
301
+ : (numPackedCols() % blockColSize());
302
+ }
303
+
304
+ inpType* buf_;
305
+ std::int32_t brow_; ///< the number of rows in each block
306
+ std::int32_t bcol_; ///< the number of columns in each block
307
+ std::int32_t nbrow_; ///< the number of blocks along rows
308
+ std::int32_t nbcol_; ///< the number of blocks along columns
309
+ bool bufAllocatedHere_{false};
310
+ const BlockingFactors*
311
+ blocking_params; ///< MCB, KCB, NCB, MR, NR, NR_MIN, ROW_INTERLEAVE;
312
+
313
+ private:
314
+ std::int32_t nrows_, ncols_;
315
+ int G_;
316
+ block_type_t packedBlock_; ///< The block in the source matrix just packed
317
+ std::int32_t last_brow_, last_bcol_;
318
+ };
319
+
320
+ /**
321
+ * @brief Matrix packed for the first input matrix in GEMM (usually
322
+ * activation). The source matrix is already quantized. Default
323
+ * accumulation type is int32.
324
+ */
325
+ template <typename T, typename accT = std::int32_t>
326
+ class FBGEMM_API PackAMatrix final
327
+ : public PackMatrix<PackAMatrix<T, accT>, T, accT> {
328
+ public:
329
+ using This = PackAMatrix<T, accT>;
330
+ using BaseType = PackMatrix<This, T, accT>;
331
+ using inpType = T;
332
+ using accType = accT;
333
+
334
+ PackAMatrix() = delete; // no default constructor
335
+
336
+ PackAMatrix(
337
+ matrix_op_t trans,
338
+ std::int32_t nRow,
339
+ std::int32_t nCol,
340
+ const inpType* smat,
341
+ std::int32_t ld,
342
+ inpType* pmat = nullptr,
343
+ int groups = 1,
344
+ const BlockingFactors* params = nullptr);
345
+
346
+ /**
347
+ * Activation matrices are not constant so cannot amortize the cost of
348
+ * pre-packing.
349
+ */
350
+ bool isPrePacked() const {
351
+ return false;
352
+ }
353
+
354
+ /**
355
+ * @return True if this is used as A matrix.
356
+ */
357
+ static constexpr bool isA() {
358
+ return true;
359
+ }
360
+
361
+ /**
362
+ * @return A pointer to the row offset buffer. There is no row offset buffer
363
+ * calculations with this packing class, hence, it returns nullptr.
364
+ */
365
+ std::int32_t* getRowOffsetBuffer() const {
366
+ return nullptr;
367
+ }
368
+
369
+ /**
370
+ * @return Offset of the element in the packed matrix that was at (i, j) in
371
+ * the source matrix.
372
+ */
373
+ std::int32_t addr(std::int32_t i, std::int32_t j) const;
374
+
375
+ /**
376
+ * @brief Packs a block of source matrix into pmat buffer.
377
+ */
378
+ void pack(const block_type_t& block);
379
+
380
+ /**
381
+ * @brief Print the packed block.
382
+ */
383
+ void printPackedMatrix(const std::string& name);
384
+
385
+ private:
386
+ matrix_op_t trans_;
387
+ const T* smat_;
388
+ std::int32_t ld_;
389
+ std::int32_t row_interleave_B_;
390
+ };
391
+
392
+ /**
393
+ * @brief Matrix packed for the second input matrix in GEMM (usually weight).
394
+ * The source matrix is already quantized. Default accumulation
395
+ * type is int32.
396
+ */
397
+ template <typename T, typename accT = std::int32_t>
398
+ class FBGEMM_API PackBMatrix final
399
+ : public PackMatrix<PackBMatrix<T, accT>, T, accT> {
400
+ public:
401
+ using This = PackBMatrix<T, accT>;
402
+ using BaseType = PackMatrix<This, T, accT>;
403
+ using inpType = T;
404
+ using accType = accT;
405
+
406
+ PackBMatrix() = delete; // no default constructor
407
+
408
+ /**
409
+ * @param groups if > 1 and trans == NoTranspose, smat is nRow x nCol with
410
+ * groups are vertically concatenated: each group is
411
+ * (nRow / groups) x nCol .
412
+ * if > 1 and trans == Transpose, smat is (nCol * groups) x
413
+ * (nRow / groups) with groups are horizontally concatenated:
414
+ * each group is nCol x (nRow / groups) . Each group is
415
+ * transposed and vertically concatenated to match with the
416
+ * NoTranspose case.
417
+ */
418
+ PackBMatrix(
419
+ matrix_op_t trans,
420
+ std::int32_t nRow,
421
+ std::int32_t nCol,
422
+ const inpType* smat,
423
+ std::int32_t ld,
424
+ inpType* pmat = nullptr,
425
+ int groups = 1,
426
+ const BlockingFactors* params = nullptr);
427
+
428
+ /**
429
+ * Weight matrices are usually constant so worth pre-packing.
430
+ */
431
+ bool isPrePacked() const {
432
+ return true;
433
+ }
434
+
435
+ /**
436
+ * @return True if to be used as A matrix, False otherwise.
437
+ */
438
+ static constexpr bool isA() {
439
+ return false;
440
+ }
441
+
442
+ /**
443
+ * @brief When k loop is also tiled/blocked, this function is used to check if
444
+ * have executed computations for the last k block so that we can perform
445
+ * post-GEMM operations.
446
+ */
447
+ bool isThisLastKBlock(int block_id) const {
448
+ return (BaseType::blockRows() - 1) == block_id;
449
+ }
450
+
451
+ /**
452
+ * @return Offset of the element in the packed matrix that was at (i, j) in
453
+ * the source matrix.
454
+ */
455
+ std::int32_t addr(std::int32_t i, std::int32_t j) const;
456
+
457
+ /**
458
+ * @brief Packs a block of source matrix into pmat buffer. The blocking
459
+ * parameters are needed to compute the buffer size of each group.
460
+ * It will use default blocking parameters if params is not provided.
461
+ */
462
+ void pack(const block_type_t& block, const BlockingFactors* params = nullptr);
463
+
464
+ /**
465
+ * @brief Print the packed block.
466
+ */
467
+ void printPackedMatrix(
468
+ const std::string& name,
469
+ const BlockingFactors* params = nullptr);
470
+
471
+ /**
472
+ * @return true if meta information like matrix shape is the same.
473
+ */
474
+ bool metaEquals(const PackBMatrix<T, accT>& that) const;
475
+ /**
476
+ * @return true if matrices are the same.
477
+ */
478
+ bool equals(const PackBMatrix<T, accT>& that) const;
479
+
480
+ /**
481
+ * @brief Unpack pmat buffer to the origin_buf (Used for the serialization to
482
+ * recover weight matrix).
483
+ */
484
+ void unpack(T* origin_buf, const BlockingFactors* params = nullptr);
485
+
486
+ ~PackBMatrix() override = default;
487
+
488
+ private:
489
+ matrix_op_t trans_;
490
+ const T* smat_;
491
+ std::int32_t ld_;
492
+ std::int32_t row_interleave_;
493
+
494
+ /**
495
+ * @brief Internal function performing both pack & unpack
496
+ */
497
+ void pack_unpack_(
498
+ const block_type_t& block,
499
+ T* unpack_buf,
500
+ T* pack_buf,
501
+ bool ispack,
502
+ const BlockingFactors* params = nullptr);
503
+ };
504
+
505
+ /**
506
+ * @brief Matrix packed for direct group convolution.
507
+ * The source matrix is already quantized. Default accumulation
508
+ * type is int32.
509
+ */
510
+ template <typename T, typename accT = std::int32_t, int SPATIAL_DIM = 2>
511
+ class FBGEMM_API PackWeightMatrixForGConv {
512
+ public:
513
+ using This = PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>;
514
+ using inpType = T;
515
+ using accType = accT;
516
+
517
+ PackWeightMatrixForGConv() = delete; // no default constructor
518
+ PackWeightMatrixForGConv(const PackWeightMatrixForGConv&) = delete; // no copy
519
+ PackWeightMatrixForGConv& operator=(const PackWeightMatrixForGConv&) =
520
+ delete; // no copy
521
+
522
+ PackWeightMatrixForGConv(PackWeightMatrixForGConv&&) = delete; // no move
523
+ PackWeightMatrixForGConv& operator=(PackWeightMatrixForGConv&&) =
524
+ delete; // no move
525
+
526
+ /**
527
+ * @param pmat if nullptr, a buffer is allocated and owned by this class.
528
+ */
529
+ PackWeightMatrixForGConv(
530
+ matrix_op_t trans,
531
+ const conv_param_t<SPATIAL_DIM>& conv_param,
532
+ const inpType* sdata,
533
+ inpType* pdata = nullptr);
534
+
535
+ /**
536
+ * Number of groups we work at a time to fill the full simd width
537
+ * e.g., IC_PER_G = 4 and OC_PER_G = 4, we work on two groups at a time
538
+ * to fill the avx2 width of 256 bits.
539
+ */
540
+ static int numOfGroupsTogether(const conv_param_t<SPATIAL_DIM>& conv_param);
541
+
542
+ /**
543
+ * @brief Packs a block of source matrix into pmat buffer.
544
+ */
545
+ void pack();
546
+
547
+ /**
548
+ * @brief Unpacks a pmat buffer into source matrix.
549
+ */
550
+ void unpack(T* origin_buf);
551
+
552
+ /**
553
+ * @brief Return packed data
554
+ */
555
+ inpType* getBuf() {
556
+ return pdata_;
557
+ }
558
+
559
+ ~PackWeightMatrixForGConv() {
560
+ if (bufAllocatedHere_) {
561
+ fbgemmAlignedFree(pdata_);
562
+ }
563
+ }
564
+
565
+ private:
566
+ matrix_op_t trans_;
567
+ const conv_param_t<SPATIAL_DIM> conv_param_;
568
+ const T* sdata_;
569
+ T* pdata_;
570
+ bool bufAllocatedHere_{false};
571
+ // Number of groups we work at a time to fill the full simd width
572
+ int GTogether_;
573
+
574
+ /**
575
+ * @brief Internal function performing both pack & unpack
576
+ */
577
+ void pack_unpack_(const T* src, T* dst, bool ispack);
578
+
579
+ /**
580
+ * @brief Get the index of the unpacked data
581
+ */
582
+ int unpacked_index_(int t, int r, int s, int k, int g, int c, bool tr);
583
+
584
+ /**
585
+ * @brief Get the index of the packed data
586
+ */
587
+ int packed_index_(int t, int r, int s, int k, int g, int c);
588
+ };
589
+
590
+ /**
591
+ * @brief A container class to keep packed weight tensor for convolution.
592
+ * The source tensor should already be quantized.
593
+ *
594
+ * @tparam SPATIAL_DIM is equal to 2 for 2D convolutions and 3 for 3D
595
+ * convolutions. Default value is 2.
596
+ * @tparam T is the datatype for source tensor. Default value is int8.
597
+ * @tparam accT is the datatype to accumulate into. Default value is int32.
598
+ */
599
+ template <
600
+ int SPATIAL_DIM = 2,
601
+ typename T = std::int8_t,
602
+ typename accT = std::int32_t>
603
+ class FBGEMM_API PackWeightsForConv {
604
+ public:
605
+ using This = PackWeightsForConv<SPATIAL_DIM, T, accT>;
606
+ using inpType = T;
607
+ using accType = accT;
608
+
609
+ PackWeightsForConv() = delete; // no default constructor
610
+
611
+ PackWeightsForConv(
612
+ const conv_param_t<SPATIAL_DIM>& conv_param,
613
+ const inpType* sdata,
614
+ const BlockingFactors* blocking_params = nullptr);
615
+
616
+ std::shared_ptr<PackBMatrix<T, accT>> getPackedWForIm2col() {
617
+ return W_im2col_packed_;
618
+ }
619
+
620
+ #if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
621
+ std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWForDepthwise() {
622
+ return W_dw_packed_;
623
+ }
624
+ #endif // __aarch64__
625
+
626
+ std::shared_ptr<PackedDirectConvMatrix> getPackedWForDirectconv() {
627
+ return W_dc_packed_;
628
+ }
629
+
630
+ std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
631
+ getPackedWForGroupwise() {
632
+ return W_gconv_packed_;
633
+ }
634
+
635
+ std::shared_ptr<PackBMatrix<T, accT>> getPackedWForPointwise() {
636
+ return W_pointwise_packed_;
637
+ }
638
+
639
+ int inputChannels() {
640
+ return conv_param_.IC;
641
+ }
642
+
643
+ int outputChannels() {
644
+ return conv_param_.OC;
645
+ }
646
+
647
+ std::array<int, SPATIAL_DIM> kernelDims() {
648
+ return conv_param_.K;
649
+ }
650
+
651
+ int groups() {
652
+ return conv_param_.G;
653
+ }
654
+
655
+ /**
656
+ * @brief Returns true if the packed weights would work for the given
657
+ * convolution parameters, and false otherwise
658
+ */
659
+ bool isPackingCompliant(const conv_param_t<SPATIAL_DIM>& conv_p);
660
+
661
+ /**
662
+ * @brief Returns a string of mismatching parameters
663
+ */
664
+ std::string mismatchingParams(const conv_param_t<SPATIAL_DIM>& conv_p);
665
+
666
+ /**
667
+ * @brief Unpack packed matric into origin_buf (Used for the serialization to
668
+ * recover weight matrix).
669
+ */
670
+ void unpack(T* origin_buf);
671
+
672
+ private:
673
+ const conv_param_t<SPATIAL_DIM> conv_param_;
674
+ // Packed weights if we use im2col based convolution implementation
675
+ std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
676
+ #if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
677
+ // Packed weights if we use depthwise convolution implementation
678
+ std::shared_ptr<PackedDepthWiseConvMatrix> W_dw_packed_;
679
+ #endif // __aarch64__
680
+ // Packed weights if we use direct convolution implementation
681
+ std::shared_ptr<PackedDirectConvMatrix> W_dc_packed_;
682
+ // Packed weights if we use groupwise (small channels per group) convolution
683
+ // implementation
684
+ std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
685
+ W_gconv_packed_;
686
+ // Packed weights if we use direct gemm for pointwise convolution
687
+ std::shared_ptr<PackBMatrix<T, accT>> W_pointwise_packed_;
688
+ };
689
+
690
+ /**
691
+ * @brief Matrix packed for the first input matrix in GEMM (usually activation),
692
+ * and row offsets used for requantization is computed during packing.
693
+ * Im2col is fused with packing here. The source matrix is already
694
+ * quantized.
695
+ */
696
+ template <typename T, typename accT = std::int32_t, int SPATIAL_DIM = 2>
697
+ class FBGEMM_API PackAWithIm2Col
698
+ : public PackMatrix<PackAWithIm2Col<T, accT, SPATIAL_DIM>, T, accT> {
699
+ public:
700
+ using This = PackAWithIm2Col<T, accT, SPATIAL_DIM>;
701
+ using BaseType = PackMatrix<This, T, accT>;
702
+ using inpType = T;
703
+ using accType = accT;
704
+
705
+ PackAWithIm2Col() = delete; // no default constructor
706
+ /**
707
+ * @param zero_pt the quantized value that maps to 0.0f floating-point number.
708
+ * @param row_offset If nullptr, this constructor internally allocates a
709
+ * buffer and owns it. Otherwise, this class doesn't own
710
+ * the buffer. The buffer will be populated when pack
711
+ * function is called.
712
+ * @param b_symmetric if true we skip row offset computation
713
+ */
714
+ PackAWithIm2Col(
715
+ const conv_param_t<SPATIAL_DIM>& conv_param,
716
+ const T* sdata,
717
+ inpType* pmat = nullptr,
718
+ std::int32_t a_zero_pt = 0,
719
+ std::int32_t* row_offset = nullptr,
720
+ bool b_symmetric = false,
721
+ const BlockingFactors* params = nullptr);
722
+
723
+ PackAWithIm2Col(const PackAWithIm2Col&) = delete;
724
+ PackAWithIm2Col(PackAWithIm2Col&&) = delete;
725
+ PackAWithIm2Col& operator=(const PackAWithIm2Col&) = delete;
726
+ PackAWithIm2Col& operator=(PackAWithIm2Col&&) = delete;
727
+
728
+ /**
729
+ * Activation matrices are not constant so cannot amortize the cost of
730
+ * pre-packing.
731
+ */
732
+ bool isPrePacked() const {
733
+ return false;
734
+ }
735
+
736
+ /**
737
+ * @return True if this is used as A matrix.
738
+ */
739
+ static constexpr bool isA() {
740
+ return true;
741
+ }
742
+
743
+ /**
744
+ * @brief Packs a block of source matrix into pmat buffer.
745
+ */
746
+ void pack(const block_type_t& block);
747
+
748
+ /**
749
+ * @return A pointer to the row offset buffer.
750
+ */
751
+ std::int32_t* getRowOffsetBuffer() const {
752
+ return row_offset_;
753
+ }
754
+
755
+ /**
756
+ * @brief Print the packed block.
757
+ */
758
+ void printPackedMatrix(const std::string& name);
759
+
760
+ /**
761
+ * @return Size of row offset buffer in number of elements
762
+ */
763
+ static int rowOffsetBufferSize(const BlockingFactors* params = nullptr);
764
+
765
+ ~PackAWithIm2Col() override {
766
+ if (rowOffsetAllocatedHere) {
767
+ fbgemmAlignedFree(row_offset_);
768
+ }
769
+ }
770
+
771
+ private:
772
+ const conv_param_t<SPATIAL_DIM> conv_p_;
773
+ const T* sdata_;
774
+ std::int32_t a_zero_pt_;
775
+ std::int32_t* row_offset_{nullptr};
776
+ bool rowOffsetAllocatedHere{false};
777
+ std::int32_t row_interleave_B_;
778
+ };
779
+
780
+ /**
781
+ * @brief Matrix packed for the first input matrix in GEMM (usually activation),
782
+ * and row offsets used for requantization is computed during packing.
783
+ * The source matrix is already quantized.
784
+ */
785
+ template <typename T, typename accT = std::int32_t>
786
+ class FBGEMM_API PackAWithRowOffset final
787
+ : public PackMatrix<PackAWithRowOffset<T, accT>, T, accT> {
788
+ public:
789
+ using This = PackAWithRowOffset<T, accT>;
790
+ using BaseType = PackMatrix<This, T, accT>;
791
+ using inpType = T;
792
+ using accType = accT;
793
+
794
+ PackAWithRowOffset() = delete; // no default constructor
795
+ /**
796
+ * @param row_offset If nullptr, this constructor internally allocates a
797
+ * buffer and owns it. Otherwise, this class doesn't own
798
+ * the buffer. The buffer will be populated when pack
799
+ * function is called.
800
+ */
801
+ PackAWithRowOffset(
802
+ matrix_op_t trans,
803
+ std::uint32_t nRow,
804
+ std::uint32_t nCol,
805
+ const T* smat,
806
+ std::uint32_t ld,
807
+ inpType* pmat = nullptr,
808
+ int groups = 1,
809
+ std::int32_t* row_offset = nullptr,
810
+ const BlockingFactors* params = nullptr);
811
+
812
+ PackAWithRowOffset(const PackAWithRowOffset&) = delete;
813
+ PackAWithRowOffset(PackAWithRowOffset&&) = delete;
814
+ PackAWithRowOffset& operator=(const PackAWithRowOffset&) = delete;
815
+ PackAWithRowOffset& operator=(PackAWithRowOffset&&) = delete;
816
+
817
+ /**
818
+ * Activation matrices are not constant so cannot amortize the cost of
819
+ * pre-packing.
820
+ */
821
+ bool isPrePacked() const {
822
+ return false;
823
+ }
824
+
825
+ /**
826
+ * @return True if this is used as A matrix.
827
+ */
828
+ static constexpr bool isA() {
829
+ return true;
830
+ }
831
+
832
+ /**
833
+ * @return Offset of the element in the packed matrix that was at (i, j) in
834
+ * the source matrix
835
+ */
836
+ std::int32_t addr(std::int32_t i, std::int32_t j) const;
837
+
838
+ /**
839
+ * @brief Packs a block of source matrix into pmat buffer.
840
+ */
841
+ void pack(const block_type_t& block);
842
+
843
+ /**
844
+ * @return A pointer to the row offset buffer.
845
+ */
846
+ std::int32_t* getRowOffsetBuffer() const {
847
+ return row_offset_;
848
+ }
849
+
850
+ /**
851
+ * @brief Print the packed block.
852
+ */
853
+ void printPackedMatrix(const std::string& name);
854
+
855
+ /**
856
+ * @return size of row offset buffer in number of elements
857
+ */
858
+ static int rowOffsetBufferSize(const BlockingFactors* params = nullptr);
859
+
860
+ ~PackAWithRowOffset() override {
861
+ if (rowOffsetAllocatedHere) {
862
+ fbgemmAlignedFree(row_offset_);
863
+ }
864
+ }
865
+
866
+ private:
867
+ matrix_op_t trans_;
868
+ const T* smat_;
869
+ std::uint32_t ld_;
870
+ std::int32_t* row_offset_{nullptr};
871
+ bool rowOffsetAllocatedHere{false};
872
+ std::int32_t row_interleave_B_;
873
+ };
874
+
875
+ /**
876
+ * @brief Matrix packed for the first input matrix in GEMM (usually activation),
877
+ * and row offsets used for requantization is computed during packing.
878
+ * The source matrix is in fp32 and quantized during packing.
879
+ */
880
+ template <typename T, typename accT = std::int32_t>
881
+ class FBGEMM_API PackAWithQuantRowOffset final
882
+ : public PackMatrix<PackAWithQuantRowOffset<T, accT>, T, accT> {
883
+ public:
884
+ using This = PackAWithQuantRowOffset<T, accT>;
885
+ using BaseType = PackMatrix<This, T, accT>;
886
+ using inpType = T;
887
+ using accType = accT;
888
+
889
+ PackAWithQuantRowOffset() = delete; // no default constructor
890
+ /**
891
+ * @param row_offset If nullptr, this constructor internally allocates a
892
+ * buffer and owns it. Otherwise, this class doesn't own
893
+ * the buffer. The buffer will be populated when pack
894
+ * function is called.
895
+ */
896
+ PackAWithQuantRowOffset(
897
+ matrix_op_t trans,
898
+ std::int32_t nRow,
899
+ std::int32_t nCol,
900
+ const float* smat,
901
+ std::int32_t ld,
902
+ inpType* pmat = nullptr,
903
+ float scale = 1.0f,
904
+ std::int32_t zero_pt = 0,
905
+ int groups = 1,
906
+ std::int32_t* row_offset = nullptr,
907
+ const BlockingFactors* params = nullptr);
908
+ PackAWithQuantRowOffset(const PackAWithQuantRowOffset&) = delete;
909
+ PackAWithQuantRowOffset(PackAWithQuantRowOffset&&) = delete;
910
+ PackAWithQuantRowOffset& operator=(const PackAWithQuantRowOffset&) = delete;
911
+ PackAWithQuantRowOffset& operator=(PackAWithQuantRowOffset&&) = delete;
912
+
913
+ /**
914
+ * Activation matrices are not constant so cannot amortize the cost of
915
+ * pre-packing.
916
+ */
917
+ bool isPrePacked() const {
918
+ return false;
919
+ }
920
+
921
+ /**
922
+ * @return True if this is used as A matrix.
923
+ */
924
+ static constexpr bool isA() {
925
+ return true;
926
+ }
927
+
928
+ /**
929
+ * @return offset of the element in the packed matrix that was at (i, j) in
930
+ * the source matrix
931
+ */
932
+ std::int32_t addr(std::int32_t i, std::int32_t j) const;
933
+
934
+ /**
935
+ * @brief Packs a block of source matrix into pmat buffer.
936
+ */
937
+ void pack(const block_type_t& block);
938
+
939
+ /**
940
+ * @return A pointer to the row offset buffer.
941
+ */
942
+ std::int32_t* getRowOffsetBuffer() const {
943
+ return row_offset_;
944
+ }
945
+
946
+ /**
947
+ * @brief Print the packed block.
948
+ */
949
+ void printPackedMatrix(const std::string& name);
950
+
951
+ /**
952
+ * @return Size of row offset buffer in number of elements
953
+ */
954
+ static int rowOffsetBufferSize(const BlockingFactors* params = nullptr);
955
+
956
+ ~PackAWithQuantRowOffset() override {
957
+ if (rowOffsetAllocatedHere) {
958
+ fbgemmAlignedFree(row_offset_);
959
+ }
960
+ }
961
+
962
+ private:
963
+ matrix_op_t trans_;
964
+ const float* smat_;
965
+ std::int32_t ld_;
966
+ float scale_;
967
+ std::int32_t zero_pt_;
968
+ std::int32_t* row_offset_{nullptr};
969
+ bool rowOffsetAllocatedHere{false};
970
+ std::int32_t row_interleave_B_;
971
+ };
972
+
973
+ /*
974
+ *
975
+ * Post Processing of outputs
976
+ *
977
+ */
978
+
979
+ /**
980
+ * @brief Does nothing. NoOp. Used as the last operation in the output
981
+ * processing pipeline.
982
+ *
983
+ */
984
+ template <typename outT = std::uint8_t, typename inT = std::uint8_t>
985
+ class FBGEMM_API DoNothing {
986
+ public:
987
+ using outType = outT;
988
+ using inpType = inT;
989
+ DoNothing() = default;
990
+ template <inst_set_t instSet>
991
+ int f(
992
+ outType* /* unused */,
993
+ inpType* /* unused */,
994
+ const block_type_t& /* unused */,
995
+ int /* unused */,
996
+ int /* unused */) const {
997
+ return 0;
998
+ }
999
+ };
1000
+
1001
+ /**
1002
+ * @brief Copy data pointed by inp ptr to out ptr when
1003
+ * inp ptr and out ptr are not the same.
1004
+ * inp buffer: row and column start points: (0, 0)
1005
+ * output buffer: row and column start points:
1006
+ * (block.row_start, block.col_start)
1007
+ *
1008
+ * This is the output processing stage that should passed when there is no
1009
+ * requantization and output is required in the same format as internal buffer
1010
+ * used for accumulation.
1011
+ */
1012
+ template <
1013
+ typename outT = std::int32_t,
1014
+ typename inT = std::int32_t,
1015
+ typename nextOPType = DoNothing<outT, outT>>
1016
+ class FBGEMM_API memCopy {
1017
+ public:
1018
+ using outType = outT;
1019
+ using inpType = inT;
1020
+ explicit memCopy(nextOPType& nextop) : nextop_(nextop) {}
1021
+ template <inst_set_t instSet>
1022
+ inline int f(
1023
+ outType* out,
1024
+ inpType* inp,
1025
+ const block_type_t& block,
1026
+ int ld_out,
1027
+ int ld_in) const;
1028
+
1029
+ private:
1030
+ nextOPType& nextop_;
1031
+ };
1032
+
1033
+ /**
1034
+ * @brief Perform scaling on accumulated data.
1035
+ */
1036
+ template <
1037
+ typename outT = std::int32_t,
1038
+ typename inT = std::int32_t,
1039
+ typename nextOPType = DoNothing<outT, outT>>
1040
+ class ScaleOP {
1041
+ public:
1042
+ using outType = outT;
1043
+ using inpType = inT;
1044
+ explicit ScaleOP(inpType scalingFactor) : scalingFactor_(scalingFactor) {}
1045
+
1046
+ template <inst_set_t instSet>
1047
+ inline int f(
1048
+ outType* out,
1049
+ inpType* inp,
1050
+ const block_type_t& block,
1051
+ int ld_out,
1052
+ int ld_in) const;
1053
+
1054
+ private:
1055
+ inpType scalingFactor_;
1056
+ };
1057
+
1058
+ /**
1059
+ * @brief Perform Relu on accumulated data.
1060
+ */
1061
+ template <
1062
+ typename outT = std::int32_t,
1063
+ typename inT = std::int32_t,
1064
+ typename nextOPType = DoNothing<outT, outT>>
1065
+ class ReluOutput {
1066
+ public:
1067
+ using outType = outT;
1068
+ using inpType = inT;
1069
+ explicit ReluOutput(inpType zero_pt) : zero_pt_(zero_pt) {}
1070
+
1071
+ template <inst_set_t instSet>
1072
+ inline int f(
1073
+ outType* out,
1074
+ inpType* inp,
1075
+ const block_type_t& block,
1076
+ int ld_out,
1077
+ int ld_in) const;
1078
+
1079
+ private:
1080
+ inpType zero_pt_;
1081
+ };
1082
+
1083
+ /**
1084
+ * @brief Perform Dense-Matrix * Sparse-Matrix as a part the of output
1085
+ * processing pipeline.
1086
+ *
1087
+ * SPMDM (SParse Matrix times Dense Matrix) inplace on the 32-bit input buffer
1088
+ * (inp). After modifying the input buffer, pass it to the next op.
1089
+ * When groups > 1, each group is numRows() x (numCols()/groups) matrix.
1090
+ */
1091
+ template <
1092
+ typename outT = std::int32_t,
1093
+ typename inT = std::int32_t,
1094
+ typename nextOPType = DoNothing<inT, inT>>
1095
+ class FBGEMM_API DoSpmdmOnInpBuffer {
1096
+ public:
1097
+ using outType = outT;
1098
+ using inpType = inT;
1099
+ DoSpmdmOnInpBuffer(
1100
+ nextOPType& nextop,
1101
+ const std::uint8_t* A,
1102
+ int lda,
1103
+ const CompressedSparseColumn& B_csc,
1104
+ int groups = 1)
1105
+ : nextop_(nextop), A_(A), lda_(lda), B_csc_(B_csc), groups_(groups) {}
1106
+
1107
+ template <inst_set_t instSet>
1108
+ inline int f(
1109
+ outT* out,
1110
+ inT* inp,
1111
+ const block_type_t& block,
1112
+ int ld_out,
1113
+ int ld_in) const;
1114
+
1115
+ private:
1116
+ nextOPType& nextop_;
1117
+ const std::uint8_t* A_;
1118
+ const int lda_;
1119
+ const CompressedSparseColumn& B_csc_;
1120
+ const int groups_;
1121
+ };
1122
+
1123
+ /**
1124
+ * @brief Perform Dense-Matrix * Sparse-Matrix as a part the of output
1125
+ * processing pipeline.
1126
+ *
1127
+ * SPMDM (SParse Matrix times Dense Matrix) inplace on the 32-bit input buffer
1128
+ * (inp). After modifying the input buffer, pass it to the next op.
1129
+ * When groups > 1, each group is numRows() x (numCols()/groups) matrix.
1130
+ */
1131
+ template <
1132
+ typename outT = std::int32_t,
1133
+ typename inT = std::int32_t,
1134
+ typename nextOPType = DoNothing<inT, inT>>
1135
+ class FBGEMM_API DoSConvOnInpBuffer {
1136
+ public:
1137
+ using outType = outT;
1138
+ using inpType = inT;
1139
+ DoSConvOnInpBuffer(
1140
+ nextOPType& nextop,
1141
+ const std::uint8_t* A,
1142
+ const conv_param_t<>& conv_p,
1143
+ std::int32_t A_zero_point,
1144
+ const CompressedSparseColumn& B_csc)
1145
+ : nextop_(nextop),
1146
+ A_(A),
1147
+ conv_p_(conv_p),
1148
+ A_zero_point_(A_zero_point),
1149
+ B_csc_(B_csc) {}
1150
+
1151
+ template <inst_set_t instSet>
1152
+ inline int f(
1153
+ outT* out,
1154
+ inT* inp,
1155
+ const block_type_t& block,
1156
+ int ld_out,
1157
+ int ld_in) const;
1158
+
1159
+ private:
1160
+ nextOPType& nextop_;
1161
+ const std::uint8_t* A_;
1162
+ const conv_param_t<> conv_p_;
1163
+ const std::int32_t A_zero_point_;
1164
+ const CompressedSparseColumn& B_csc_;
1165
+ };
1166
+
1167
+ /**
1168
+ * @brief Requantize values in inp buffer and write to out buffer.
1169
+ * pass the out buffer to next op for further processing.
1170
+ */
1171
+ template <
1172
+ bool FUSE_RELU,
1173
+ QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR,
1174
+ typename BIAS_TYPE = std::int32_t,
1175
+ typename outT = std::uint8_t,
1176
+ typename inT = std::int32_t,
1177
+ typename nextOPType = DoNothing<outT, outT>>
1178
+ class FBGEMM_API ReQuantizeOutput {
1179
+ public:
1180
+ static constexpr int RELU_FUSED = FUSE_RELU;
1181
+ static constexpr QuantizationGranularity QGRANType = Q_GRAN;
1182
+ using BIAS_T = BIAS_TYPE;
1183
+ using outType = outT;
1184
+ using inpType = inT;
1185
+ /**
1186
+ * @param C_multiplier The length of this array is
1187
+ * 1 when Q_GRAN == QuantizationGranularity::TENSOR,
1188
+ * groups when Q_GRAN == QuantizationGranularity::GROUP,
1189
+ * nCol if Q_GRAN == QuantizationGranularity::OUT_CHANNEL
1190
+ * @param Bq_zero_point The length of this array should be the same as
1191
+ * C_multiplier.
1192
+ * @param row_offsets Typically, this should've been computed by a
1193
+ * PackAMatrix and should be obtained by
1194
+ * PackMatrix::getRowOffsetBuffer().
1195
+ * If Bq_zero_point == 0 (symmetric quantization of B
1196
+ * matrix), we can pass nullptr.
1197
+ * @param col_offsets This should be pre-computed for example using
1198
+ * col_offsets_with_zero_pt_s8acc32_ref.
1199
+ * The length should be nCol.
1200
+ * See PackedRequantizeTest.cc for an example.
1201
+ * TODO: if Aq_zero_point == 0, allow passing nullptr.
1202
+ * @param bias can be nullptr otherwise the length should be nCol
1203
+ * @param act_times_w_scale activation_scale * weight_scale. This is only
1204
+ * used if bias is unquantized (i.e., float).
1205
+ */
1206
+ ReQuantizeOutput(
1207
+ nextOPType& nextop,
1208
+ const float* C_multiplier,
1209
+ std::int32_t C_zero_point,
1210
+ std::int32_t Aq_zero_point,
1211
+ const std::int32_t* Bq_zero_point,
1212
+ const std::int32_t* row_offsets,
1213
+ const std::int32_t* col_offsets,
1214
+ const BIAS_T* bias,
1215
+ std::uint32_t nCol,
1216
+ int groups = 1,
1217
+ const float* act_times_w_scale = nullptr)
1218
+ : nextop_(nextop),
1219
+ C_multiplier_(C_multiplier),
1220
+ C_zero_point_(C_zero_point),
1221
+ Aq_zero_point_(Aq_zero_point),
1222
+ Bq_zero_point_(Bq_zero_point),
1223
+ q_row_offsets_(row_offsets),
1224
+ q_col_offsets_(col_offsets),
1225
+ bias_(bias),
1226
+ ncols_(nCol),
1227
+ groups_(groups),
1228
+ act_times_w_scale_(act_times_w_scale) {}
1229
+
1230
+ template <inst_set_t instSet>
1231
+ inline int f(
1232
+ outT* out,
1233
+ const inT* inp,
1234
+ const block_type_t& block,
1235
+ int ld_out,
1236
+ int ld_in) const;
1237
+
1238
+ const float* getCMultiplier() const {
1239
+ return C_multiplier_;
1240
+ }
1241
+ std::int32_t getAZeroPoint() const {
1242
+ return Aq_zero_point_;
1243
+ }
1244
+ std::int32_t getCZeroPoint() const {
1245
+ return C_zero_point_;
1246
+ }
1247
+ const std::int32_t* getBZeroPoint() const {
1248
+ return Bq_zero_point_;
1249
+ }
1250
+ const std::int32_t* getRowOffsets() const {
1251
+ return q_row_offsets_;
1252
+ }
1253
+ const std::int32_t* getColOffsets() const {
1254
+ return q_col_offsets_;
1255
+ }
1256
+ const BIAS_T* getBias() const {
1257
+ return bias_;
1258
+ }
1259
+ std::uint32_t getNCols() const {
1260
+ return ncols_;
1261
+ }
1262
+ const float* getActWScale() const {
1263
+ return act_times_w_scale_;
1264
+ }
1265
+
1266
+ void setRowOffsets(const std::int32_t* row_offsets) {
1267
+ q_row_offsets_ = row_offsets;
1268
+ }
1269
+
1270
+ private:
1271
+ nextOPType& nextop_;
1272
+ const float* C_multiplier_;
1273
+ std::int32_t C_zero_point_;
1274
+ std::int32_t Aq_zero_point_;
1275
+ const std::int32_t* Bq_zero_point_;
1276
+ const std::int32_t* q_row_offsets_;
1277
+ const std::int32_t* q_col_offsets_;
1278
+ const BIAS_T* bias_;
1279
+ std::uint32_t ncols_;
1280
+ int groups_;
1281
+ const float* act_times_w_scale_;
1282
+ };
1283
+
1284
+ /**
1285
+ * @brief Requantize to convert accumulated data to be used as float, i.e., the
1286
+ * output would be used as float.
1287
+ */
1288
+ template <
1289
+ bool FUSE_RELU,
1290
+ QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR,
1291
+ typename outT = float,
1292
+ typename inT = std::int32_t,
1293
+ typename nextOPType = DoNothing<outT, outT>>
1294
+ class FBGEMM_API ReQuantizeForFloat {
1295
+ public:
1296
+ using outType = outT;
1297
+ using inpType = inT;
1298
+ /**
1299
+ * @param Bq_scale The length of this array is
1300
+ * 1 when Q_GRAN == QuantizationGranularity::TENSOR,
1301
+ * groups when Q_GRAN == QuantizationGranularity::GROUP,
1302
+ * nCol if Q_GRAN == QuantizationGranularity::OUT_CHANNEL
1303
+ * @param Bq_zero_point The length of this array should be the same as
1304
+ * Bq_scale.
1305
+ * @param row_offsets Typically, this should've been computed by a
1306
+ * PackAMatrix and should be obtained by
1307
+ * PackMatrix::getRowOffsetBuffer().
1308
+ * If Bq_zero_point == 0 (symmetric quantization of B
1309
+ * matrix), we can pass nullptr.
1310
+ * @param col_offsets This should be pre-computed for example using
1311
+ * col_offsets_with_zero_pt_s8acc32_ref.
1312
+ * The length should be nCol.
1313
+ * See PackedRequantizeTest.cc for an example.
1314
+ * TODO: if Aq_zero_point == 0, allow passing nullptr.
1315
+ * @param bias can be nullptr otherwise the length should be nCol
1316
+ */
1317
+ ReQuantizeForFloat(
1318
+ nextOPType& nextop,
1319
+ float Aq_scale,
1320
+ const float* Bq_scale,
1321
+ std::int32_t Aq_zero_point,
1322
+ const std::int32_t* Bq_zero_point,
1323
+ const std::int32_t* row_offsets,
1324
+ const std::int32_t* col_offsets,
1325
+ const float* bias,
1326
+ std::uint32_t nCol,
1327
+ int groups = 1)
1328
+ : nextop_(nextop),
1329
+ Aq_scale_(Aq_scale),
1330
+ Bq_scale_(Bq_scale),
1331
+ Aq_zero_point_(Aq_zero_point),
1332
+ Bq_zero_point_(Bq_zero_point),
1333
+ q_row_offsets_(row_offsets),
1334
+ q_col_offsets_(col_offsets),
1335
+ bias_(bias),
1336
+ ncols_(nCol),
1337
+ groups_(groups) {}
1338
+
1339
+ template <inst_set_t instSet>
1340
+ inline int f(
1341
+ outT* out,
1342
+ inT* inp,
1343
+ const block_type_t& block,
1344
+ int ld_out,
1345
+ int ld_in) const;
1346
+
1347
+ private:
1348
+ nextOPType& nextop_;
1349
+ float Aq_scale_;
1350
+ const float* Bq_scale_;
1351
+ std::int32_t Aq_zero_point_;
1352
+ const std::int32_t* Bq_zero_point_;
1353
+ const std::int32_t* q_row_offsets_;
1354
+ const std::int32_t* q_col_offsets_;
1355
+ const float* bias_;
1356
+ std::uint32_t ncols_;
1357
+ int groups_;
1358
+ };
1359
+
1360
+ // type specialized implementation in an include file
1361
+ #include "./OutputProcessing-inl.h" // @manual
1362
+
1363
+ /*
1364
+ *
1365
+ * ####### GEMM related functions #######
1366
+ *
1367
+ */
1368
+
1369
+ /**
1370
+ * Matrix B must be prepacked. For matrix A, packA.pack function is called to
1371
+ * pack it.
1372
+ *
1373
+ * @tparam packingAMatrix processing of A matrix while packing,
1374
+ * e.g., PackAWithQuantRowOffset
1375
+ *
1376
+ * @tparam packingBMatrix processing of B matrix while packing,
1377
+ * e.g., pre-multiply by alpha
1378
+ * @tparam cT data type of C matrix
1379
+ * @tparam processOutputType further processing of outputs, e.g., Relu
1380
+ */
1381
+ template <
1382
+ typename packingAMatrix,
1383
+ typename packingBMatrix,
1384
+ typename cT,
1385
+ typename processOutputType>
1386
+ FBGEMM_API void fbgemmPacked(
1387
+ PackMatrix<
1388
+ packingAMatrix,
1389
+ typename packingAMatrix::inpType,
1390
+ typename packingAMatrix::accType>& packA,
1391
+ PackMatrix<
1392
+ packingBMatrix,
1393
+ typename packingBMatrix::inpType,
1394
+ typename packingBMatrix::accType>& packB,
1395
+ cT* C,
1396
+ std::int32_t* C_buffer,
1397
+ std::uint32_t ldc,
1398
+ const processOutputType& outProcess,
1399
+ int thread_id,
1400
+ int num_threads,
1401
+ const BlockingFactors* blocking_params = nullptr);
1402
+
1403
+ /**
1404
+ * @brief Perform small-channels-per-group groupwise convolution
1405
+ * Note: Currently threading is not supported. This function does
1406
+ * nothing for thread_ids > 0, i.e., returns early.
1407
+ *
1408
+ * @param rowOffsetBuf nullptr if B uses symmetric quantization
1409
+ * Note: Currently threading is not supported. This function does
1410
+ * nothing for thread_ids > 0, i.e., returns early.
1411
+ */
1412
+ template <
1413
+ typename packed_W,
1414
+ typename outType,
1415
+ bool FUSE_RELU,
1416
+ QuantizationGranularity Q_GRAN,
1417
+ int SPATIAL_DIM = 2,
1418
+ typename BIAS_TYPE = std::int32_t>
1419
+ FBGEMM_API void fbgemmGroupwiseConv(
1420
+ const conv_param_t<SPATIAL_DIM>& conv_param,
1421
+ const std::uint8_t* activations,
1422
+ std::int32_t a_zero_point,
1423
+ std::int32_t* rowOffsetBuf,
1424
+ packed_W& packed_weights,
1425
+ outType* out,
1426
+ std::int32_t* outBuffer,
1427
+ const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
1428
+ int thread_id,
1429
+ int num_threads);
1430
+
1431
+ template <
1432
+ int SPATIAL_DIM,
1433
+ QuantizationGranularity Q_GRAN,
1434
+ bool FUSE_RELU,
1435
+ typename BIAS_TYPE = std::int32_t>
1436
+ FBGEMM_API void fbgemmDirectConv(
1437
+ const conv_param_t<SPATIAL_DIM>& conv_p,
1438
+ const uint8_t* Aint8,
1439
+ PackedDirectConvMatrix& Bint8_tr,
1440
+ uint8_t* C,
1441
+ int32_t* C_buffer,
1442
+ const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
1443
+ const BIAS_TYPE* bias,
1444
+ int thread_id,
1445
+ int num_threads);
1446
+
1447
+ /**
1448
+ * @return Size of row offset buffer in number of elements needed for
1449
+ * fbgemmGroupwiseConv
1450
+ */
1451
+ template <int SPATIAL_DIM = 2>
1452
+ FBGEMM_API int rowOffsetBufferSizeGConv(
1453
+ const conv_param_t<SPATIAL_DIM>& conv_param);
1454
+
1455
+ /**
1456
+ * @brief Is this depthwise convolution optimized?
1457
+ */
1458
+ template <int SPATIAL_DIM = 2, typename ACC_T = std::int32_t>
1459
+ bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p);
1460
+
1461
+ /**
1462
+ * @brief Is this groupwise convolution supported?
1463
+ */
1464
+ template <int SPATIAL_DIM>
1465
+ FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p);
1466
+
1467
+ /**
1468
+ * @brief Is this convolution a direct matrix-matrix multiplication, i.e., 1x1
1469
+ * (aka pointwise) with right paddings etc.?
1470
+ */
1471
+ template <int SPATIAL_DIM>
1472
+ FBGEMM_API bool takePointWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p);
1473
+
1474
+ /**
1475
+ * @brief Are we running on a fbgemm supported cpu?
1476
+ */
1477
+ FBGEMM_API bool fbgemmSupportedCPU();
1478
+
1479
+ /**
1480
+ * @brief Performs convolution using fastest path available.
1481
+ *
1482
+ * @tparam SPATIAL_DIM It's 2 for 2D convolutions and 3 for 3D convolutions.
1483
+ */
1484
+ template <
1485
+ typename processOutputType,
1486
+ int SPATIAL_DIM = 2,
1487
+ typename ACC_T = std::int32_t>
1488
+ FBGEMM_API int fbgemmConv(
1489
+ const conv_param_t<SPATIAL_DIM>& conv_p,
1490
+ const std::uint8_t* activations,
1491
+ PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights,
1492
+ typename processOutputType::outType* out,
1493
+ std::int32_t* outBuffer,
1494
+ processOutputType& outProcess,
1495
+ int thread_id,
1496
+ int num_threads,
1497
+ const BlockingFactors* blocking_params = nullptr);
1498
+
1499
+ /**
1500
+ * @brief Returns which fast path to take
1501
+ *
1502
+ * @tparam SPATIAL_DIM It's 2 for 2D convolutions and 3 for 3D convolutions.
1503
+ *
1504
+ * @return optimized_conv_t::depthwise, optimized_conv_t::groupwise or
1505
+ * optimized_conv_t::im2col
1506
+ *
1507
+ */
1508
+ template <int SPATIAL_DIM = 2, typename ACC_T = std::int32_t>
1509
+ FBGEMM_API optimized_conv_t
1510
+ ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p);
1511
+ } // namespace fbgemm
1512
+
1513
+ #else
1514
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
1515
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmBuild.h ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ /*
3
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ * All rights reserved.
5
+ *
6
+ * This source code is licensed under the BSD-style license found in the
7
+ * LICENSE file in the root directory of this source tree.
8
+ */
9
+
10
+ #pragma once
11
+
12
+ // For details about dllexport/dllimport, checkout the following SO question
13
+ // https://stackoverflow.com/questions/57999/what-is-the-difference-between-dllexport-and-dllimport
14
+ #if !defined(FBGEMM_API)
15
+ #if defined(FBGEMM_STATIC)
16
+ #define FBGEMM_API
17
+ #define FBGEMM_ENUM_CLASS_API
18
+ #elif defined _WIN32 || defined __CYGWIN__
19
+ #if (__GNUC__ || __clang__) && !(__MINGW64__ || __MINGW32__)
20
+ #if defined(FBGEMM_EXPORTS)
21
+ #define FBGEMM_API __attribute__((__dllexport__))
22
+ #else
23
+ #define FBGEMM_API __attribute__((__dllimport__))
24
+ #endif
25
+ #else
26
+ #if defined(FBGEMM_EXPORTS)
27
+ #define FBGEMM_API __declspec(dllexport)
28
+ #else
29
+ #define FBGEMM_API __declspec(dllimport)
30
+ #endif
31
+ #endif
32
+ #define FBGEMM_ENUM_CLASS_API
33
+ #else
34
+ #if __clang__ || __GNUC__ || __INTEL_COMPILER
35
+ #define FBGEMM_API __attribute__((__visibility__("default")))
36
+ #else
37
+ #define FBGEMM_API
38
+ #endif
39
+ // Currently, enum classes need to be declaredly explicitly for shared build on
40
+ // macos
41
+ #if __clang__
42
+ #define FBGEMM_ENUM_CLASS_API __attribute__((__visibility__("default")))
43
+ #else
44
+ #define FBGEMM_ENUM_CLASS_API
45
+ #endif
46
+ #endif
47
+ #endif
48
+
49
+ // Use this to indicate to not inline functions
50
+ #if __clang__ || __GNUC__ || __INTEL_COMPILER
51
+ #define NOINLINE __attribute__((noinline))
52
+ #elif _MSC_VER
53
+ #define NOINLINE __declspec(noinline)
54
+ #else
55
+ #define NOINLINE
56
+ #endif
57
+
58
+ // Use this to indicate always inline functions
59
+ #if __clang__ || __GNUC__ || __INTEL_COMPILER
60
+ #define ALWAYS_INLINE inline __attribute__((__always_inline__))
61
+ #elif _MSC_VER
62
+ // commenting out because __forceinline takes too long time in MSVC
63
+ #define ALWAYS_INLINE // __forceinline
64
+ #else
65
+ #define ALWAYS_INLINE inline
66
+ #endif
67
+
68
+ // Use the C++11 keyword "alignas" if you can
69
+ #if _MSC_VER
70
+ #define ALIGNAS(byte_alignment) __declspec(align(byte_alignment))
71
+ #else
72
+ #define ALIGNAS(byte_alignment) __attribute__((aligned(byte_alignment)))
73
+ #endif
74
+
75
+ // Sanitizers annotations
76
+ #if defined(__has_attribute)
77
+ #if __has_attribute(no_sanitize)
78
+ #define NO_SANITIZE(what) __attribute__((no_sanitize(what)))
79
+ #endif
80
+ #endif
81
+ #if !defined(NO_SANITIZE)
82
+ #define NO_SANITIZE(what)
83
+ #endif
84
+
85
+ // Ignore __builtin_assume() when not supported by compiler.
86
+ #ifndef __has_builtin
87
+ #define __has_builtin(x) 0
88
+ #endif
89
+ #if !__has_builtin(__builtin_assume)
90
+ #define __builtin_assume(x) (static_cast<void>(0))
91
+ #endif
92
+
93
+ // Macro for silencing warnings
94
+ #if __clang__ || __GNUC__
95
+ // clang-format off
96
+ #define FBGEMM_PUSH_WARNING _Pragma("GCC diagnostic push")
97
+ #define FBGEMM_DISABLE_WARNING_INTERNAL2(warningName) #warningName
98
+ #define FBGEMM_DISABLE_WARNING(warningName) \
99
+ _Pragma( \
100
+ FBGEMM_DISABLE_WARNING_INTERNAL2(GCC diagnostic ignored warningName))
101
+ #define FBGEMM_PUSH_WARNING_AND_DISABLE(warningName) \
102
+ _Pragma("GCC diagnostic push") \
103
+ _Pragma( \
104
+ FBGEMM_DISABLE_WARNING_INTERNAL2(GCC diagnostic ignored warningName))
105
+ #define FBGEMM_POP_WARNING _Pragma("GCC diagnostic pop")
106
+ // clang-format on
107
+ #else
108
+ #define FBGEMM_PUSH_WARNING
109
+ #define FBGEMM_DISABLE_WARNING(NAME)
110
+ #define FBGEMM_PUSH_WARNING_AND_DISABLE(NAME)
111
+ #define FBGEMM_POP_WARNING
112
+ #endif
113
+
114
+ #else
115
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
116
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmConvert.h ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ /*
3
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ * All rights reserved.
5
+ *
6
+ * This source code is licensed under the BSD-style license found in the
7
+ * LICENSE file in the root directory of this source tree.
8
+ */
9
+
10
+ #pragma once
11
+
12
+ #include <cstddef>
13
+ #include <cstdint>
14
+ #include "fbgemm/FbgemmBuild.h"
15
+ #include "fbgemm/Types.h"
16
+
17
+ namespace fbgemm {
18
+
19
+ /**
20
+ * @ Transform all entries in a matrix from fp32 to bfloat16: reference
21
+ * implementation.
22
+ *
23
+ */
24
+ FBGEMM_API void
25
+ FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size);
26
+
27
+ /**
28
+ * @ Transform all entries in a matrix from bfloat16 to fp32: reference
29
+ * implementation.
30
+ *
31
+ */
32
+ FBGEMM_API void
33
+ Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size);
34
+
35
+ /**
36
+ * @ Transform all entries in a matrix from fp32 to bfloat16: simd
37
+ * implementation.
38
+ *
39
+ */
40
+ FBGEMM_API void
41
+ FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size);
42
+
43
+ /**
44
+ * @ Transform all entries in a matrix from bfloat16 to fp32: simd
45
+ * implementation.
46
+ *
47
+ */
48
+ FBGEMM_API void
49
+ Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size);
50
+
51
+ #if !defined(__aarch64__)
52
+ /**
53
+ * @brief AVX2 implementation to convert fp32 numbers to bf16 numbers.
54
+ *
55
+ */
56
+ FBGEMM_API void
57
+ FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size);
58
+
59
+ /**
60
+ * @brief AVX512 implementation to convert fp32 numbers to bf16 numbers.
61
+ *
62
+ */
63
+ FBGEMM_API void
64
+ FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size);
65
+
66
+ /**
67
+ * @brief AVX2 implementation to convert bf16 numbers to fp32 numbers.
68
+ *
69
+ */
70
+ FBGEMM_API void
71
+ Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size);
72
+
73
+ /**
74
+ * @brief AVX512 implementation to convert bf16 numbers to fp32 numbers.
75
+ *
76
+ */
77
+ FBGEMM_API void
78
+ Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size);
79
+ #endif
80
+
81
+ /**
82
+ * @ Transform all entries in a matrix from fp32 to float16: reference
83
+ * implementation.
84
+ *
85
+ * @param do_clip if true we saturate to fp16 min and max instead of generating
86
+ * infinities.
87
+ */
88
+ FBGEMM_API void FloatToFloat16_ref(
89
+ const float* src,
90
+ float16* dst,
91
+ size_t size,
92
+ bool do_clip = false);
93
+
94
+ /**
95
+ * @ Transform all entries in a matrix from float16 to fp32: reference
96
+ * implementation.
97
+ *
98
+ */
99
+ FBGEMM_API void Float16ToFloat_ref(const float16* src, float* dst, size_t size);
100
+
101
+ /**
102
+ * @ Transform all entries in a matrix from fp32 to float16: simd
103
+ * implementation.
104
+ *
105
+ * @param do_clip if true we saturate to fp16 min and max instead of generating
106
+ * infinities.
107
+ */
108
+ FBGEMM_API void FloatToFloat16_simd(
109
+ const float* src,
110
+ float16* dst,
111
+ size_t size,
112
+ bool do_clip = false);
113
+
114
+ /**
115
+ * @ Transform all entries in a matrix from float16 to fp32: simd
116
+ * implementation.
117
+ *
118
+ */
119
+ FBGEMM_API void
120
+ Float16ToFloat_simd(const float16* src, float* dst, size_t size);
121
+
122
+ /**
123
+ * @brief AVX2 implementation to convert fp32 numbers to fp16 numbers.
124
+ *
125
+ */
126
+ #if !defined(__aarch64__)
127
+ FBGEMM_API void FloatToFloat16_avx2(
128
+ const float* src,
129
+ float16* dst,
130
+ size_t size,
131
+ bool do_clip = false);
132
+
133
+ /**
134
+ * @brief AVX512 implementation to convert fp32 numbers to fp16 numbers.
135
+ *
136
+ */
137
+ FBGEMM_API void FloatToFloat16_avx512(
138
+ const float* src,
139
+ float16* dst,
140
+ size_t size,
141
+ bool do_clip = false);
142
+ #endif
143
+
144
+ /**
145
+ * @brief SVE2 implementation to convert fp32 numbers to fp16 numbers.
146
+ *
147
+ */
148
+ FBGEMM_API void FloatToFloat16_sve2(
149
+ const float* src,
150
+ float16* dst,
151
+ size_t size,
152
+ bool do_clip = false);
153
+
154
+ #if !defined(__aarch64__)
155
+ /**
156
+ * @brief AVX2 implementation to convert fp16 numbers to fp32 numbers.
157
+ *
158
+ */
159
+ FBGEMM_API void
160
+ Float16ToFloat_avx2(const float16* src, float* dst, size_t size);
161
+
162
+ /**
163
+ * @brief AVX512 implementation to convert fp16 numbers to fp32 numbers.
164
+ *
165
+ */
166
+ FBGEMM_API void
167
+ Float16ToFloat_avx512(const float16* src, float* dst, size_t size);
168
+ #endif
169
+
170
+ /**
171
+ * @brief Transform all entries in a matrix from fp32 to float16 and back to
172
+ * fp32.
173
+ */
174
+ FBGEMM_API void RoundToFloat16(
175
+ const float* input,
176
+ float* output,
177
+ size_t size,
178
+ bool clamp = false,
179
+ bool clamp_denorms = false);
180
+
181
+ /**
182
+ * @brief Quantize float32 to float8. The code is a copy of float_to_hfp8() in
183
+ * fbgemm_gpu/quantize_ops_utils.h
184
+ */
185
+ FBGEMM_API void FloatToFloat8_ref(
186
+ float input,
187
+ uint8_t* output,
188
+ int exponent_bits,
189
+ int exponent_bias);
190
+
191
+ /**
192
+ * @brief Dequantize float8 to float32. The code is a copy of hf8_to_float() in
193
+ * fbgemm_gpu/quantize_ops_utils.h
194
+ */
195
+ FBGEMM_API void Float8ToFloat_ref(
196
+ uint8_t input,
197
+ float* output,
198
+ int exponent_bits,
199
+ int exponent_bias);
200
+
201
+ } // namespace fbgemm
202
+
203
+ #else
204
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
205
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmEmbedding.h ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ /*
3
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ * All rights reserved.
5
+ *
6
+ * This source code is licensed under the BSD-style license found in the
7
+ * LICENSE file in the root directory of this source tree.
8
+ */
9
+
10
+ #pragma once
11
+ #include <cstdint>
12
+ #include <functional>
13
+
14
+ #include "fbgemm/FbgemmBuild.h"
15
+
16
+ namespace fbgemm {
17
+
18
+ template <
19
+ typename InType,
20
+ typename IndexType,
21
+ typename OffsetType = std::int32_t,
22
+ typename OutType = float>
23
+ class EmbeddingSpMDMKernelSignature {
24
+ public:
25
+ /**
26
+ * Behavior is as the follow pseudocode
27
+ * (when use_offsets == true, lengths[i] == offsets[i + 1] - offsets[i])
28
+ * (when is_weight_positional == true, use weights[j - offsets[i]] instead of
29
+ * weights[j])
30
+ *
31
+ * for i in range(output_size):
32
+ * out[i * block_size : (i + 1) * block_size] = 0
33
+ * for j in range(offsets[i], offsets[i + 1]):
34
+ * for k in range(block_size):
35
+ * out[i * block_size + k] += input[indices[j] * block_size + k] *
36
+ * weights ? weights[j] : 1;
37
+ * if normalize_weights and lengths[i] > 0:
38
+ * out[i * block_size : (i + 1) * block_size] /= lengths[i]
39
+ *
40
+ * @param data_size the number of rows in embedding table
41
+ */
42
+ using Type = std::function<bool(
43
+ std::int64_t output_size,
44
+ std::int64_t index_size,
45
+ std::int64_t data_size,
46
+ const InType* input,
47
+ const IndexType* indices,
48
+ const OffsetType* offsets_or_lengths,
49
+ const float* weights, // optional, can be null for non-weighted sum
50
+ OutType* out)>;
51
+ };
52
+
53
+ /**
54
+ * @tparam InType can be float, float16, or uint8_t
55
+ * @tparam IndexType can be int32_t or int64_t
56
+ * @tparam IndexType can be int32_t or int64_t
57
+ *
58
+ * @param use_offsets If true, the generated code assumes we will pass offsets
59
+ * instead of lengths that confirms PyTorch EmbeddingBag
60
+ * interface. In this case, the length of offsets array
61
+ * should be output_size + 1 and offsets[output_size] should
62
+ * be index_size.
63
+ * If false, the generate code assumes we will pass lengths
64
+ * that confirms Caffe2 SparseLengthsSum interface.
65
+ */
66
+ template <
67
+ typename InType,
68
+ typename IndexType,
69
+ typename OffsetType = std::int32_t,
70
+ typename OutType = float,
71
+ bool THREAD_LOCAL = false>
72
+ FBGEMM_API typename EmbeddingSpMDMKernelSignature<
73
+ InType,
74
+ IndexType,
75
+ OffsetType,
76
+ OutType>::Type
77
+ GenerateEmbeddingSpMDM(
78
+ const std::int64_t block_size,
79
+ bool has_weight,
80
+ bool normalize_by_lengths,
81
+ int prefetch = 16,
82
+ bool is_weight_positional = false,
83
+ bool use_offsets = true,
84
+ bool is_bf16_out = false,
85
+ bool is_bf16_in = false);
86
+
87
+ /**
88
+ * @param output_stride If -1, output_stride is same as block_size
89
+ * @param input_stride If -1, input_stride is same as block_size
90
+ * @param scale_bias_last if false, scale and bias appear at the beginning
91
+ * of each row and are in fp16 for table batched embedding (TBE)
92
+ * in FBGEMM_GPU. If false, it can also take -1 indices (output from
93
+ * pruned embedding id mapping)
94
+ */
95
+ template <
96
+ typename InType,
97
+ typename IndexType,
98
+ typename OffsetType = std::int32_t,
99
+ typename OutType = float,
100
+ bool THREAD_LOCAL = false>
101
+ FBGEMM_API typename EmbeddingSpMDMKernelSignature<
102
+ InType,
103
+ IndexType,
104
+ OffsetType,
105
+ OutType>::Type
106
+ GenerateEmbeddingSpMDMWithStrides(
107
+ const std::int64_t block_size,
108
+ bool has_weight,
109
+ bool normalize_by_lengths,
110
+ int prefetch = 16,
111
+ bool is_weight_positional = false,
112
+ bool use_offsets = true,
113
+ std::int64_t output_stride = -1,
114
+ std::int64_t input_stride = -1,
115
+ bool scale_bias_last = true,
116
+ bool no_bag = false,
117
+ bool is_bf16_out = false,
118
+ bool is_bf16_in = false);
119
+
120
+ /**
121
+ * @tparam IndexType can be int32_t or int64_t
122
+ * @tparam OffsetType can be int32_t or int64_t
123
+ * @param bit_rate can be 2 or 4
124
+ */
125
+ template <
126
+ typename IndexType,
127
+ typename OffsetType = std::int32_t,
128
+ typename OutType = float>
129
+ FBGEMM_API typename EmbeddingSpMDMKernelSignature<
130
+ std::uint8_t,
131
+ IndexType,
132
+ OffsetType,
133
+ OutType>::Type
134
+ GenerateEmbeddingSpMDMNBit(
135
+ int bit_rate,
136
+ const std::int64_t block_size,
137
+ bool has_weight,
138
+ bool normalize_by_lengths,
139
+ int prefetch = 16,
140
+ bool is_weight_positional = false,
141
+ bool use_offsets = true);
142
+
143
+ /**
144
+ * @param output_stride If -1, output_stride is same as block_size
145
+ * @param input_stride in Bytes. If -1, input_stride is same as
146
+ * block_size / num_elem_per_byte + 2 * sizeof(float16)
147
+ * @param scale_bias_last if false, scale and bias appear at the beginning
148
+ * of each row and are in fp16 for table batched embedding (TBE)
149
+ * in FBGEMM_GPU. If false, it can also take -1 indices (output from
150
+ * pruned embedding id mapping)
151
+ */
152
+ template <
153
+ typename IndexType,
154
+ typename OffsetType = std::int32_t,
155
+ typename OutType = float,
156
+ bool THREAD_LOCAL = false>
157
+ FBGEMM_API typename EmbeddingSpMDMKernelSignature<
158
+ std::uint8_t,
159
+ IndexType,
160
+ OffsetType,
161
+ OutType>::Type
162
+ GenerateEmbeddingSpMDMNBitWithStrides(
163
+ const int input_bit_rate,
164
+ const std::int64_t block_size,
165
+ bool has_weight,
166
+ bool normalize_by_lengths,
167
+ int prefetch = 16,
168
+ bool is_weight_positional = false,
169
+ bool use_offsets = true,
170
+ std::int64_t output_stride = -1,
171
+ std::int64_t input_stride = -1,
172
+ bool scale_bias_last = true,
173
+ const bool is_bf16_out = false,
174
+ const bool no_bag = false,
175
+ int output_bit_rate = -1);
176
+
177
+ /**
178
+ * @param output_stride If -1, output_stride is same as block_size
179
+ * @param input_stride in Bytes. If -1, input_stride is same as
180
+ * block_size / num_elem_per_byte + 2 * sizeof(float16)
181
+ * @param exponent_bits is the number of exponent bits in the FP8 encode
182
+ * (normally 4 or 5)
183
+ * @param exponent_bias is subtracted from the exponent to obtain the actual
184
+ * exponent for the floating-point number
185
+ */
186
+ template <
187
+ typename IndexType,
188
+ typename OffsetType = std::int32_t,
189
+ typename OutType = float>
190
+ FBGEMM_API typename EmbeddingSpMDMKernelSignature<
191
+ std::uint8_t,
192
+ IndexType,
193
+ OffsetType,
194
+ OutType>::Type
195
+ GenerateEmbeddingSpMDMFP8WithStrides(
196
+ const std::int64_t block_size,
197
+ bool normalize_by_lengths,
198
+ bool is_weight_positional = false,
199
+ bool use_offsets = true,
200
+ std::int64_t output_stride = -1,
201
+ std::int64_t input_stride = -1,
202
+ int exponent_bits = 4,
203
+ int exponent_bias = 7,
204
+ bool is_bf16_out = false);
205
+
206
+ template <
207
+ typename InType,
208
+ typename IndexType,
209
+ typename OffsetType = std::int32_t>
210
+ class EmbeddingSpMDMRowWiseSparseKernelSignature {
211
+ public:
212
+ using Type = std::function<bool(
213
+ std::int64_t output_size,
214
+ std::int64_t index_size,
215
+ std::int64_t uncompressed_data_size,
216
+ // TODO: add compressed_data_size and check array bound
217
+ const InType* input,
218
+ const IndexType* indices,
219
+ const OffsetType* offsets_or_lengths,
220
+ const float* weights, // optional, can be null for non-weighted sum
221
+ float* out,
222
+ const std::int32_t* compressed_indices_table)>;
223
+ };
224
+
225
+ /**
226
+ * @tparam InType can be float, float16, or uint8_t
227
+ * @tparam IndexType can be int32_t or int64_t
228
+ * @tparam OffsetType can be int32_t or int64_t
229
+ */
230
+ template <
231
+ typename InType,
232
+ typename IndexType,
233
+ typename OffsetType = std::int32_t>
234
+ FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature<
235
+ InType,
236
+ IndexType,
237
+ OffsetType>::Type
238
+ GenerateEmbeddingSpMDMRowWiseSparse(
239
+ const std::int64_t block_size,
240
+ bool has_weight,
241
+ bool normalize_by_lengths,
242
+ int prefetch = 16,
243
+ bool is_weight_positional = false,
244
+ bool use_offsets = true);
245
+
246
+ /**
247
+ * @tparam IndexType can be int32_t or int64_t
248
+ * @tparam OffsetType can be int32_t or int64_t
249
+ * @param bit_rate can be 2 or 4
250
+ */
251
+ template <typename IndexType, typename OffsetType = std::int32_t>
252
+ FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature<
253
+ std::uint8_t,
254
+ IndexType,
255
+ OffsetType>::Type
256
+ GenerateEmbeddingSpMDMNBitRowWiseSparse(
257
+ int bit_rate,
258
+ const std::int64_t block_size,
259
+ bool has_weight,
260
+ bool normalize_by_lengths,
261
+ int prefetch = 16,
262
+ bool is_weight_positional = false,
263
+ bool use_offsets = true);
264
+
265
+ /**
266
+ * @return The number of rows processed. If smaller than num_rows, an error
267
+ * must have happened at the last row processed.
268
+ */
269
+ template <typename IndexType>
270
+ class SparseAdaGradSignature {
271
+ public:
272
+ using Type = std::function<int(
273
+ int num_rows, // number of rows reading
274
+ std::uint64_t param_size, // total number of parameters
275
+ float* w, // input/output parameters
276
+ const float* g, // input gradients
277
+ float* h, // input/output momentums
278
+ const IndexType* indices, // indices of each row
279
+ float epsilon,
280
+ float lr,
281
+ float weight_decay,
282
+ const double* counter, // used for weight_decay adjusted for frequency
283
+ // nullptr when frequency adjustment is not used.
284
+ // ignored when the kernel is generated with
285
+ // use_weight_decay = false.
286
+ std::int64_t counter_halflife)>; // frequency adjust happens only after
287
+ };
288
+
289
+ template <typename IndexType>
290
+ FBGEMM_API typename SparseAdaGradSignature<IndexType>::Type
291
+ GenerateSparseAdaGrad(
292
+ int block_size, // number of parameters per row
293
+ bool rowwise = false,
294
+ int prefetch = 16,
295
+ bool use_weight_decay = false);
296
+
297
+ // RowWiseSparseAdaGrad fused with SLS gradient
298
+ // Weights can be either float or float16
299
+ template <
300
+ typename IndexType,
301
+ typename OffsetType = std::int32_t,
302
+ typename DataType = float>
303
+ class RowWiseSparseAdaGradFusedSignature {
304
+ public:
305
+ using Type = std::function<bool(
306
+ std::int64_t output_size,
307
+ std::int64_t index_size,
308
+ std::int64_t data_size, // number of rows in w
309
+ DataType* w, // input/output parameters
310
+ const float* g, // input gradients
311
+ float* h, // input/output momentums
312
+ const IndexType* indices, // indices of each row
313
+ const OffsetType* offsets_or_lengths,
314
+ float epsilon,
315
+ float lr)>;
316
+ };
317
+
318
+ /**
319
+ * @param grad_stride If -1, grad_stride is same as block size
320
+ */
321
+ template <
322
+ typename IndexType,
323
+ typename OffsetType = std::int32_t,
324
+ typename DataType = float>
325
+ FBGEMM_API typename RowWiseSparseAdaGradFusedSignature<
326
+ IndexType,
327
+ OffsetType,
328
+ DataType>::Type
329
+ GenerateRowWiseSparseAdaGradFused(
330
+ int block_size, // number of parameters per row
331
+ int prefetch = 16,
332
+ bool use_offsets = true,
333
+ bool use_stochastic_rounding = true,
334
+ int grad_stride = -1);
335
+
336
+ namespace internal {
337
+ // Specialization for block size 1 internally called by GenerateEmbeddingSpMDM
338
+ template <typename InType, typename IndexType, typename OffsetType>
339
+ FBGEMM_API bool EmbeddingSpMDMBlockSize1_(
340
+ const std::int64_t output_size,
341
+ const std::int64_t index_size,
342
+ const std::int64_t data_size, // the number of rows in input
343
+ const InType* input,
344
+ const IndexType* indices,
345
+ const OffsetType* offsets_or_lengths,
346
+ const float* weights, // optional, can be null for non-weighted sum
347
+ bool normalize_by_lengths,
348
+ float* out,
349
+ bool is_weight_positional = false,
350
+ bool use_offsets = true,
351
+ bool is_bf16 = false);
352
+
353
+ #if !defined(__aarch64__)
354
+ template <typename IndexType, bool HAS_WEIGHTS>
355
+ void compressed_indices_remap_avx512(
356
+ std::int32_t offsets_numel,
357
+ const IndexType* indices,
358
+ const int32_t* compressed_indices_mapping,
359
+ const IndexType* offsets,
360
+ const float* weights, // optional, can be null,
361
+ IndexType* out_indices,
362
+ IndexType* out_offsets,
363
+ float* out_weights);
364
+ #endif
365
+
366
+ } // namespace internal
367
+
368
+ template <typename IndexType>
369
+ FBGEMM_API void compressed_indices_remap(
370
+ std::int32_t offsets_numel,
371
+ const IndexType* indices,
372
+ const int32_t* compressed_indices_mapping,
373
+ const IndexType* offsets,
374
+ const float* weights, // optional, can be null,
375
+ IndexType* out_indices,
376
+ IndexType* out_offsets,
377
+ float* out_weights);
378
+
379
+ } // namespace fbgemm
380
+
381
+ #else
382
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
383
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP16.h ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ /*
3
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ * All rights reserved.
5
+ *
6
+ * This source code is licensed under the BSD-style license found in the
7
+ * LICENSE file in the root directory of this source tree.
8
+ */
9
+
10
+ #pragma once
11
+
12
+ // WARNING: this is a legacy fp16 fbgemm implementation and will soon be
13
+ // upgraded to match with new fbgemm interface.
14
+
15
+ #include <cpuinfo.h>
16
+
17
+ #include "./FbgemmPackMatrixB.h" // @manual
18
+ #include "./FloatConversion.h" // @manual
19
+ #include "./Types.h" // @manual
20
+ #include "./Utils.h" // @manual
21
+
22
+ namespace fbgemm {
23
+
24
+ template <>
25
+ struct TypeConverter<float16> {
26
+ float16 operator()(float src) const {
27
+ constexpr float FP16_MAX = 65504.f;
28
+ const float fp16 = std::max(-FP16_MAX, std::min(src, FP16_MAX));
29
+ return cpu_float2half(fp16);
30
+ }
31
+ };
32
+
33
+ using PackedGemmMatrixFP16 = PackedGemmMatrixB<float16>;
34
+
35
+ template <typename T>
36
+ FBGEMM_API void cblas_gemm_compute(
37
+ const matrix_op_t transa,
38
+ const int m,
39
+ const float* A,
40
+ const PackedGemmMatrixB<T>& Bp,
41
+ const float beta,
42
+ float* C,
43
+ int thread_id = 0,
44
+ int num_threads = 1);
45
+
46
+ extern template void cblas_gemm_compute<float16>(
47
+ const matrix_op_t transa,
48
+ const int m,
49
+ const float* A,
50
+ const PackedGemmMatrixFP16& Bp,
51
+ const float beta,
52
+ float* C,
53
+ int thread_id,
54
+ int num_threads);
55
+
56
+ }; // namespace fbgemm
57
+
58
+ #else
59
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
60
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP32.h ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
3
+
4
+ #pragma once
5
+
6
+ // WARNING: this is a legacy fp16 fbgemm implementation and will soon be
7
+ // upgraded to match with new fbgemm interface.
8
+
9
+ #include <cpuinfo.h>
10
+
11
+ #include "fbgemm/FbgemmFPCommon.h"
12
+ #include "fbgemm/FbgemmPackMatrixB.h"
13
+ #include "fbgemm/Utils.h"
14
+
15
+ namespace fbgemm {
16
+ template <>
17
+ struct TypeConverter<float> {
18
+ float operator()(float src) const {
19
+ return src;
20
+ }
21
+ };
22
+
23
+ using GemmParamsFP32 = GemmParams<float>;
24
+ using PackedGemmMatrixFP32 = PackedGemmMatrixB<float>;
25
+
26
+ template <typename T, int _kernel_ncol_blocks, int _brow>
27
+ void cblas_gemm_compute(
28
+ const matrix_op_t transa,
29
+ const int m,
30
+ const float* A,
31
+ const PackedGemmMatrixB<T>& Bp,
32
+ const float beta,
33
+ float* C,
34
+ int thread_id = 0,
35
+ int num_threads = 1);
36
+
37
+ extern template void cblas_gemm_compute(
38
+ const matrix_op_t transa,
39
+ const int m,
40
+ const float* A,
41
+ const PackedGemmMatrixFP32& Bp,
42
+ const float beta,
43
+ float* C,
44
+ int thread_id,
45
+ int num_threads);
46
+
47
+ template <>
48
+ const isa_descriptor<float>& getIsaHandlers(inst_set_t isa);
49
+
50
+ } // namespace fbgemm
51
+
52
+ #else
53
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
54
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFPCommon.h ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ /*
3
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ * Copyright 2024-2025 Arm Limited and/or its affiliates
5
+ * <open-source-office@arm.com> All rights reserved.
6
+ *
7
+ * This source code is licensed under the BSD-style license found in the
8
+ * LICENSE file in the root directory of this source tree.
9
+ */
10
+
11
+ #pragma once
12
+
13
+ #include <fbgemm/FbgemmPackMatrixB.h>
14
+ #include <fbgemm/Types.h>
15
+ #include <fbgemm/Utils.h>
16
+ #include <array>
17
+ #include <memory>
18
+
19
+ #if defined(FBGEMM_FP16_FALLBACK_TO_REF_KERNEL) || \
20
+ defined(FBGEMM_FP32_FALLBACK_TO_REF_KERNEL)
21
+ #if defined(__APPLE__) && defined(__aarch64__)
22
+ #define FBGEMM_USE_REF_KERNEL
23
+ #endif
24
+ #endif
25
+
26
+ namespace fbgemm {
27
+
28
+ using partition_array_t = std::array<std::array<std::array<int, 2>, 2>, 121>;
29
+ extern partition_array_t partition_avx2;
30
+ extern partition_array_t partition_avx512;
31
+ extern partition_array_t partition_sve128;
32
+ #ifdef FBGEMM_ENABLE_KLEIDIAI
33
+ extern partition_array_t partition_neon;
34
+ #endif
35
+
36
+ template <typename T>
37
+ struct GemmParams {
38
+ uint64_t k;
39
+ float* A;
40
+ const T* B;
41
+ float beta;
42
+ float* C;
43
+ uint64_t ldc;
44
+ uint64_t b_block_cols;
45
+ uint64_t b_block_size;
46
+ };
47
+
48
+ template <>
49
+ struct GemmParams<float16> {
50
+ uint64_t k;
51
+ float* A;
52
+ const float16* B;
53
+ float beta;
54
+ float* C;
55
+ uint64_t ldc;
56
+ uint64_t b_block_cols;
57
+ #ifdef FBGEMM_ENABLE_KLEIDIAI
58
+ uint64_t lda;
59
+ #else
60
+ uint64_t b_block_size;
61
+ #endif
62
+ };
63
+
64
+ template <>
65
+ struct GemmParams<float> {
66
+ uint64_t k;
67
+ float* A;
68
+ const float* B;
69
+ float beta;
70
+ float* C;
71
+ uint64_t ldc;
72
+ uint64_t b_block_cols;
73
+ #ifdef FBGEMM_ENABLE_KLEIDIAI
74
+ uint64_t lda;
75
+ #else
76
+ uint64_t b_block_size;
77
+ #endif
78
+ };
79
+
80
+ template <typename T>
81
+ using funcptr_t = void (*)(GemmParams<T>*);
82
+ template <typename T>
83
+ using kernel_array_t = std::array<funcptr_t<T>, 15>;
84
+ template <typename T>
85
+ using isa_descriptor = std::tuple<kernel_array_t<T>, partition_array_t>;
86
+
87
+ template <typename T>
88
+ extern const isa_descriptor<T>& getIsaHandlers(inst_set_t isa);
89
+
90
+ void PackA(int nrow, int ncol, const float* from, int ldim, float* to);
91
+
92
+ // define fp16/fp32 kernels using a reference C implementation
93
+ #if defined(FBGEMM_FP16_FALLBACK_TO_REF_KERNEL) || \
94
+ defined(FBGEMM_FP32_FALLBACK_TO_REF_KERNEL)
95
+ template <typename T>
96
+ FBGEMM_API void ref_kernel(
97
+ int kernel_nrows,
98
+ GemmParams<T>* gp,
99
+ const float* C_base,
100
+ int m_total,
101
+ int n_total,
102
+ int vlen);
103
+ #endif
104
+
105
+ template <typename T>
106
+ FBGEMM_API void cblas_gemm_compute(
107
+ const matrix_op_t transa,
108
+ const int m,
109
+ const float* A,
110
+ const PackedGemmMatrixB<T>& Bp,
111
+ const float beta,
112
+ float* C,
113
+ int thread_id = 0,
114
+ int num_threads = 1);
115
+
116
+ #if defined(FBGEMM_EXPORTS)
117
+ // autotuned kernel splits for various cases m = 1:mb_max
118
+ template <typename T>
119
+ void cblas_gemm_compute(
120
+ const matrix_op_t transa [[maybe_unused]],
121
+ const int m,
122
+ const float* A,
123
+ const PackedGemmMatrixB<T>& Bp,
124
+ const float beta,
125
+ float* C,
126
+ int thread_id,
127
+ int num_threads) {
128
+ // ground truth
129
+ assert(cpuinfo_initialize());
130
+ #ifndef __aarch64__
131
+ assert(cpuinfo_has_x86_fma3());
132
+ assert(cpuinfo_has_x86_f16c());
133
+ #endif
134
+ assert(transa == matrix_op_t::NoTranspose);
135
+
136
+ // private scratchpad storage
137
+ static thread_local std::unique_ptr<std::array<float, 256 * 1024>> scratchpad(
138
+ new std::array<float, 256 * 1024>());
139
+
140
+ // constants
141
+ const int n = Bp.numCols(), k = Bp.numRows(), ldc = n;
142
+ const int mb_max = 120;
143
+
144
+ #if defined(FBGEMM_USE_REF_KERNEL) && defined(__APPLE__)
145
+ const auto& [_, partition] = getIsaHandlers<float16>(inst_set_t::sve);
146
+ #else
147
+ const auto iset = fbgemmInstructionSet();
148
+ const auto& [kernels, partition] = getIsaHandlers<T>(iset);
149
+ #endif
150
+
151
+ #ifdef FBGEMM_USE_REF_KERNEL
152
+ // By some reason, if packed B is using packing layout for avx2, we just use
153
+ // avx2 even if avx512 is available.
154
+ const int simd_width =
155
+ #ifndef __aarch64__
156
+ (iset == inst_set_t::avx512 || iset == inst_set_t::avx512_vnni) &&
157
+ (Bp.blockColSize() == 16 * Bp.kernelNumColBlocks())
158
+ ? simd_info<inst_set_t::avx512>::WIDTH_32BIT_ELEMS
159
+ : simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
160
+ #else
161
+ simd_info<inst_set_t::sve>::WIDTH_32BIT_ELEMS;
162
+ #endif
163
+ #endif
164
+
165
+ GemmParams<T> gp;
166
+ int i_begin = 0, i_end = 0;
167
+ i_begin = 0;
168
+ i_end = m;
169
+ for (auto m0 = i_begin; m0 < i_end; m0 += mb_max) {
170
+ int mb = std::min(mb_max, i_end - m0);
171
+ assert(mb < static_cast<int64_t>(partition.size()));
172
+ for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) {
173
+ // set up proper accumulation to avoid "Nan" problem
174
+ // accumulate of beta != 0.0
175
+ // do not!!! accumulate otherwise
176
+ float beta_ = beta;
177
+ if (k_ind != 0) {
178
+ // always accumulate with beta_ = 1.0f
179
+ beta_ = 1.0f;
180
+ }
181
+
182
+ const int kb = std::min(Bp.blockRowSize(), Bp.numRows() - k_ind);
183
+
184
+ auto m1 = m0;
185
+ auto const num_cycles = partition[mb].size();
186
+ for (size_t c = 0; c < num_cycles; ++c) {
187
+ auto kernel_nrows = partition[mb][c][0];
188
+ auto nkernel_nrows = partition[mb][c][1];
189
+ auto m_start = m1;
190
+ auto m_end = m1 + kernel_nrows * nkernel_nrows;
191
+ for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) {
192
+ assert(kernel_nrows * kb < static_cast<int64_t>(scratchpad->size()));
193
+ if (m != 1) {
194
+ #ifdef FBGEMM_ENABLE_KLEIDIAI
195
+ if constexpr (
196
+ std::is_same<T, float16>::value ||
197
+ std::is_same<T, float>::value) {
198
+ gp.A = const_cast<float*>(&A[m2 * k + k_ind]);
199
+ } else {
200
+ #endif
201
+ PackA(
202
+ kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data());
203
+ gp.A = scratchpad->data();
204
+ #ifdef FBGEMM_ENABLE_KLEIDIAI
205
+ }
206
+ #endif
207
+ } else {
208
+ // When m == 1, it is actually vector matrix multiplication. We
209
+ // don't need to do the transposition for packA here. Instead, we
210
+ // can just pass the pointer of the original A matrix buffer to the
211
+ // packed A buffer.
212
+ gp.A = const_cast<float*>(&A[k_ind]);
213
+ }
214
+
215
+ int nbcol = n / Bp.blockColSize();
216
+ gp.k = kb;
217
+ gp.B = &(Bp(k_ind, 0));
218
+ gp.beta = beta_;
219
+ gp.C = &C[m2 * ldc];
220
+ gp.ldc = ldc * sizeof(C[0]);
221
+ gp.b_block_cols = nbcol;
222
+ #ifdef FBGEMM_ENABLE_KLEIDIAI
223
+ if constexpr (
224
+ std::is_same<T, float16>::value ||
225
+ std::is_same<T, float>::value) {
226
+ gp.lda = k * sizeof(A[0]);
227
+ } else {
228
+ #endif
229
+ gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]);
230
+ #ifdef FBGEMM_ENABLE_KLEIDIAI
231
+ }
232
+ #endif
233
+ if ((n % Bp.blockColSize()) == 0) {
234
+ int64_t jb_begin = 0, jb_end = 0;
235
+ fbgemmPartition1D(
236
+ thread_id, num_threads, gp.b_block_cols, jb_begin, jb_end);
237
+ gp.B += gp.k * Bp.blockColSize() * jb_begin;
238
+ gp.C += Bp.blockColSize() * jb_begin;
239
+ gp.b_block_cols = jb_end - jb_begin;
240
+ if (gp.b_block_cols) {
241
+ #ifdef FBGEMM_USE_REF_KERNEL
242
+ ref_kernel<T>(kernel_nrows, &gp, C, m, n, simd_width);
243
+ #else
244
+ kernels[kernel_nrows](&gp);
245
+ #endif
246
+ }
247
+ } else {
248
+ int last_blk_col = nbcol * Bp.blockColSize();
249
+ if (nbcol) {
250
+ int64_t jb_begin = 0, jb_end = 0;
251
+ fbgemmPartition1D(
252
+ thread_id, num_threads, gp.b_block_cols, jb_begin, jb_end);
253
+ gp.B += gp.k * Bp.blockColSize() * jb_begin;
254
+ gp.C += Bp.blockColSize() * jb_begin;
255
+ gp.b_block_cols = jb_end - jb_begin;
256
+ if (gp.b_block_cols) {
257
+ #ifdef FBGEMM_USE_REF_KERNEL
258
+ ref_kernel(kernel_nrows, &gp, C, m, n, simd_width);
259
+ #else
260
+ kernels[kernel_nrows](&gp);
261
+ #endif
262
+ }
263
+ }
264
+
265
+ // use one thread to handle the fringe cases
266
+ if (thread_id == num_threads - 1) {
267
+ // leftover
268
+ const int rem [[maybe_unused]] = n - last_blk_col;
269
+ assert(rem < Bp.blockColSize());
270
+
271
+ // small temporary buffer: the size should be larger than the
272
+ // required kernel_nrow x kernel_ncols elements computed in the
273
+ // registers.
274
+ std::array<float, 14 * 32> c_tmp{0.f};
275
+ assert(
276
+ static_cast<int64_t>(c_tmp.size()) >=
277
+ kernel_nrows * Bp.blockColSize());
278
+
279
+ gp.B = &(Bp(k_ind, last_blk_col));
280
+ gp.C = c_tmp.data();
281
+ gp.ldc = Bp.blockColSize() * sizeof(C[0]);
282
+ gp.b_block_cols = 1;
283
+ #ifdef FBGEMM_USE_REF_KERNEL
284
+ ref_kernel<T>(
285
+ kernel_nrows, &gp, c_tmp.data(), 14, 32, simd_width);
286
+ #else
287
+ kernels[kernel_nrows](&gp);
288
+ #endif
289
+ for (int i = 0; i < kernel_nrows; i++) {
290
+ // Todo: use assembly
291
+ for (int j = last_blk_col; j < n; j++) {
292
+ assert(
293
+ i * Bp.blockColSize() + (j - last_blk_col) <
294
+ static_cast<int64_t>(sizeof(c_tmp) / sizeof(c_tmp[0])));
295
+ if (beta_ == 0.f) {
296
+ C[(m2 + i) * ldc + j] =
297
+ c_tmp[i * Bp.blockColSize() + (j - last_blk_col)];
298
+ } else {
299
+ C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] +
300
+ c_tmp[i * Bp.blockColSize() + (j - last_blk_col)];
301
+ }
302
+ }
303
+ }
304
+ }
305
+ }
306
+ }
307
+ m1 += kernel_nrows * nkernel_nrows;
308
+ }
309
+ }
310
+ }
311
+ }
312
+ #endif
313
+
314
+ #undef FBGEMM_USE_REF_KERNEL
315
+ } // namespace fbgemm
316
+
317
+ #else
318
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
319
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI64.h ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ /*
3
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ * All rights reserved.
5
+ *
6
+ * This source code is licensed under the BSD-style license found in the
7
+ * LICENSE file in the root directory of this source tree.
8
+ */
9
+
10
+ #pragma once
11
+
12
+ #include <cstdint>
13
+
14
+ #include "fbgemm/Utils.h"
15
+
16
+ namespace fbgemm {
17
+
18
+ FBGEMM_API void cblas_gemm_i64_i64acc(
19
+ matrix_op_t transa,
20
+ matrix_op_t transb,
21
+ int M,
22
+ int N,
23
+ int K,
24
+ const std::int64_t* A,
25
+ int lda,
26
+ const std::int64_t* B,
27
+ int ldb,
28
+ bool accumulate,
29
+ std::int64_t* C,
30
+ int ldc);
31
+
32
+ } // namespace fbgemm
33
+
34
+ #else
35
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
36
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8DepthwiseAvx2.h ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ /*
3
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ * All rights reserved.
5
+ *
6
+ * This source code is licensed under the BSD-style license found in the
7
+ * LICENSE file in the root directory of this source tree.
8
+ */
9
+
10
+ #pragma once
11
+
12
+ #include <cstdint>
13
+ #include "fbgemm/ConvUtils.h"
14
+ #include "fbgemm/FbgemmBuild.h"
15
+ #include "fbgemm/UtilsAvx2.h"
16
+
17
+ namespace fbgemm {
18
+
19
+ class FBGEMM_API PackedDepthWiseConvMatrix {
20
+ public:
21
+ /**
22
+ * @param IC the number of input channels (same as the number of groups
23
+ * because depth-wise convolution has one input channel per group)
24
+ * @param OC the number of output channels
25
+ * @param kernel_prod the product of all kernels. For example, kernel_prod =
26
+ * 9 for 3x3 conv, and 27 for 3x3x3 conv.
27
+ * @param smat the source unpacked weight in GRS layout
28
+ */
29
+ PackedDepthWiseConvMatrix(int OC, int kernel_prod, const std::int8_t* smat);
30
+ PackedDepthWiseConvMatrix(const PackedDepthWiseConvMatrix&) = delete;
31
+ PackedDepthWiseConvMatrix(PackedDepthWiseConvMatrix&&) = delete;
32
+ PackedDepthWiseConvMatrix& operator=(const PackedDepthWiseConvMatrix&) =
33
+ delete;
34
+ PackedDepthWiseConvMatrix& operator=(PackedDepthWiseConvMatrix&&) = delete;
35
+ virtual ~PackedDepthWiseConvMatrix();
36
+
37
+ const std::int8_t* PackedMat() const {
38
+ return pmat_;
39
+ }
40
+
41
+ int GetKernelProduct() const {
42
+ return kernel_prod_;
43
+ }
44
+
45
+ /**
46
+ * @brief Unpacks pmat_ into unpack_data.
47
+ * Used for recovering the weight matrix into the original format
48
+ */
49
+ void unpack(std::int8_t* unpacked_data);
50
+
51
+ /**
52
+ * @brief returns the index into pmat_ given the row and column for smat
53
+ */
54
+ int addr(int r, int c);
55
+
56
+ private:
57
+ const int OC_; /**< the number of output channels */
58
+ const int kernel_prod_; /** the product of all kernel dims */
59
+ std::int8_t* pmat_; /** packed weight */
60
+ }; // PackedDepthWiseConvMatrix
61
+
62
+ /**
63
+ * Depth-wise convolution that results in the same output feature size as the
64
+ * input feature. That is PAD_T = PAD_B = (R - 1) / 2 and PAD_L = PAD_R =
65
+ * (S - 1) / 2. This function also does requantization.
66
+ * @param col_offsets nullptr if col_offsets are folded into bias
67
+ * @param act_times_w_scale Only used if BIAS_TYPE is float, i.e., bias is
68
+ * unquantized.
69
+ */
70
+ template <QuantizationGranularity Q_GRAN, typename BIAS_TYPE = std::int32_t>
71
+ FBGEMM_API void depthwise_2d_same_pad(
72
+ int N,
73
+ int H,
74
+ int W,
75
+ int IC,
76
+ int OC,
77
+ int stride_h,
78
+ int stride_w,
79
+ std::int32_t A_zero_point,
80
+ const std::uint8_t* A,
81
+ const std::int32_t* B_zero_point,
82
+ const PackedDepthWiseConvMatrix& Bp,
83
+ const float* C_multiplier,
84
+ std::int32_t C_zero_point,
85
+ std::uint8_t* C,
86
+ const std::int32_t* col_offsets,
87
+ const BIAS_TYPE* bias,
88
+ bool fuse_relu = false,
89
+ const float* act_times_w_scale = nullptr,
90
+ int thread_id = 0,
91
+ int num_threads = 1);
92
+
93
+ /**
94
+ * @param col_offsets nullptr if col_offsets are folded into bias
95
+ */
96
+ template <QuantizationGranularity Q_GRAN, typename BIAS_TYPE = std::int32_t>
97
+ FBGEMM_API void depthwise_3d_same_pad(
98
+ const conv_param_t<3>& conv_p,
99
+ std::int32_t A_zero_point,
100
+ const std::uint8_t* A,
101
+ const std::int32_t* B_zero_point,
102
+ const PackedDepthWiseConvMatrix& Bp,
103
+ const float* C_multiplier,
104
+ std::int32_t C_zero_point,
105
+ std::uint8_t* C,
106
+ const std::int32_t* col_offsets,
107
+ const BIAS_TYPE* bias,
108
+ bool fuse_relu = false,
109
+ const float* act_times_w_scale = nullptr,
110
+ int thread_id = 0,
111
+ int num_threads = 1);
112
+
113
+ } // namespace fbgemm
114
+
115
+ #else
116
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
117
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)