koichi12 commited on
Commit
104185d
·
verified ·
1 Parent(s): 44b4c93

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-311.pyc +3 -0
  3. .venv/lib/python3.11/site-packages/torch/include/ATen/AccumulateType.h +173 -0
  4. .venv/lib/python3.11/site-packages/torch/include/ATen/Backend.h +2 -0
  5. .venv/lib/python3.11/site-packages/torch/include/ATen/CPUApplyUtils.h +343 -0
  6. .venv/lib/python3.11/site-packages/torch/include/ATen/CPUFixedAllocator.h +33 -0
  7. .venv/lib/python3.11/site-packages/torch/include/ATen/CPUFunctions.h +29 -0
  8. .venv/lib/python3.11/site-packages/torch/include/ATen/CPUGeneratorImpl.h +49 -0
  9. .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h +553 -0
  10. .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h +29 -0
  11. .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h +29 -0
  12. .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h +25 -0
  13. .venv/lib/python3.11/site-packages/torch/include/ATen/Config.h +21 -0
  14. .venv/lib/python3.11/site-packages/torch/include/ATen/Context.h +610 -0
  15. .venv/lib/python3.11/site-packages/torch/include/ATen/Device.h +2 -0
  16. .venv/lib/python3.11/site-packages/torch/include/ATen/DeviceAccelerator.h +27 -0
  17. .venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h +808 -0
  18. .venv/lib/python3.11/site-packages/torch/include/ATen/EmptyTensor.h +166 -0
  19. .venv/lib/python3.11/site-packages/torch/include/ATen/ExpandBase.h +30 -0
  20. .venv/lib/python3.11/site-packages/torch/include/ATen/FuncTorchTLS.h +46 -0
  21. .venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalStorageImpl.h +208 -0
  22. .venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalTensorWrapper.h +454 -0
  23. .venv/lib/python3.11/site-packages/torch/include/ATen/InferSize.h +88 -0
  24. .venv/lib/python3.11/site-packages/torch/include/ATen/InitialTensorOptions.h +15 -0
  25. .venv/lib/python3.11/site-packages/torch/include/ATen/Layout.h +2 -0
  26. .venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedFallback.h +25 -0
  27. .venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h +160 -0
  28. .venv/lib/python3.11/site-packages/torch/include/ATen/LegacyVmapMode.h +26 -0
  29. .venv/lib/python3.11/site-packages/torch/include/ATen/LegacyVmapTransforms.h +183 -0
  30. .venv/lib/python3.11/site-packages/torch/include/ATen/MapAllocator.h +143 -0
  31. .venv/lib/python3.11/site-packages/torch/include/ATen/MatrixRef.h +107 -0
  32. .venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions.h +29 -0
  33. .venv/lib/python3.11/site-packages/torch/include/ATen/MethodOperators.h +443 -0
  34. .venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensor.h +1 -0
  35. .venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensorUtils.h +214 -0
  36. .venv/lib/python3.11/site-packages/torch/include/ATen/NativeFunctions.h +1344 -0
  37. .venv/lib/python3.11/site-packages/torch/include/ATen/NestedTensorImpl.h +286 -0
  38. .venv/lib/python3.11/site-packages/torch/include/ATen/OpMathType.h +69 -0
  39. .venv/lib/python3.11/site-packages/torch/include/ATen/PadNd.h +28 -0
  40. .venv/lib/python3.11/site-packages/torch/include/ATen/Parallel.h +158 -0
  41. .venv/lib/python3.11/site-packages/torch/include/ATen/ParallelFuture.h +13 -0
  42. .venv/lib/python3.11/site-packages/torch/include/ATen/RegistrationDeclarations.h +0 -0
  43. .venv/lib/python3.11/site-packages/torch/include/ATen/SavedTensorHooks.h +66 -0
  44. .venv/lib/python3.11/site-packages/torch/include/ATen/Scalar.h +3 -0
  45. .venv/lib/python3.11/site-packages/torch/include/ATen/ScalarOps.h +53 -0
  46. .venv/lib/python3.11/site-packages/torch/include/ATen/ScalarType.h +4 -0
  47. .venv/lib/python3.11/site-packages/torch/include/ATen/SparseCsrTensorImpl.h +206 -0
  48. .venv/lib/python3.11/site-packages/torch/include/ATen/SparseCsrTensorUtils.h +441 -0
  49. .venv/lib/python3.11/site-packages/torch/include/ATen/Storage.h +2 -0
  50. .venv/lib/python3.11/site-packages/torch/include/ATen/Tensor.h +3 -0
.gitattributes CHANGED
@@ -145,3 +145,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
145
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
146
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
147
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
145
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
146
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
147
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
148
+ .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b6d007391b31b1b010874c0dfd08680792c26a7542040197359230b54bc6612
3
+ size 114085
.venv/lib/python3.11/site-packages/torch/include/ATen/AccumulateType.h ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once
#include <ATen/Config.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>

// Defines the accumulation type for a scalar type.
// Example:
// using accscalar_t = acc_type<scalar_t, /*is_cuda*/true>;
//
// Accumulation types are an important concept in numeric computing
// because you frequently want to perform intermediate computations
// at a higher precision than the input and output precision, to avoid
// compounding internal rounding errors. Accumulation is the most
// well-known intermediate computation (it is of great importance for
// sum reduction and matrix multiply, for example), but in PyTorch
// acc_type ends up getting used for all sorts of other intermediate
// computations, so it perhaps would be more accurately (ahem) called an
// "accurate" type. acc_type is especially important for reduced
// precision operations like float16 and bfloat16, where relatively
// benign looking inputs can easily end up overflowing/underflowing.
//
// acc_type is parametrized by whether or not you are running on CUDA
// or not, because on CUDA double precision operations are expensive
// and so by default, we don't actually want to use double as an
// acc_type on CUDA. A lot of things are typed out below, but
// basically, the table is generated by a few rules:
//
// If bool:
// Use 'bool' as acc_type.
// If floating point:
// If CUDA, use 'float' as acc_type (unless scalar_t is double),
// otherwise (CPU) use 'double'
// If integral:
// Use 'int64_t' as acc_type
//
// You're not forced to use this template; if you happen to know
// something specific about your use case, you can specify your own
// desired behavior. This template, however, will give you a reasonable
// default that will work for all dtypes supported in PyTorch.

// Device-compiler-only headers: the raw `half` type (distinct from
// c10::Half) only exists when compiling under nvcc/hipcc, which is why
// the CUDA_ACC_TYPE(half, ...) specialization below is guarded the same way.
#if defined(__CUDACC__)
#include <cuda.h>
#include <cuda_fp16.h>
#elif defined(__HIPCC__)
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#endif

namespace at {

// Primary template of the per-(scalar type, device) accumulation table.
// Only the specializations generated below provide a `type` member; an
// unlisted (T, D) combination is therefore a compile-time error.
template <typename T, c10::DeviceType D>
struct AccumulateTypeDevice {};

// Legacy boolean selector: false -> CPU table, true -> CUDA table.
template <typename T, bool>
struct AccumulateType {};

template <typename T>
struct AccumulateType<T, false> {
  using type = typename AccumulateTypeDevice<T, c10::DeviceType::CPU>::type;
};

template <typename T>
struct AccumulateType<T, true> {
  using type = typename AccumulateTypeDevice<T, c10::DeviceType::CUDA>::type;
};

// Preferred alias: select the accumulation type by explicit device type.
template <typename T, c10::DeviceType device>
using acc_type_device = typename AccumulateTypeDevice<T, device>::type;

// Older alias: select by an is_cuda flag (CPU vs CUDA only).
template <typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type;

// Generates one AccumulateTypeDevice specialization mapping t -> acc_t
// for the given device type.
#define ACC_TYPE(t, acc_t, device_type)         \
  template <>                                   \
  struct AccumulateTypeDevice<t, device_type> { \
    using type = acc_t;                         \
  };
#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU)
#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)

// NOTE: on MPS, double and c10::complex<double> accumulate in *float*
// precision, unlike the CUDA/CPU tables below.
MPS_ACC_TYPE(BFloat16, float);
MPS_ACC_TYPE(Half, float);
MPS_ACC_TYPE(Float8_e5m2, float);
MPS_ACC_TYPE(Float8_e4m3fn, float);
MPS_ACC_TYPE(Float8_e5m2fnuz, float);
MPS_ACC_TYPE(Float8_e4m3fnuz, float);
MPS_ACC_TYPE(float, float);
MPS_ACC_TYPE(double, float);
MPS_ACC_TYPE(int8_t, int64_t);
MPS_ACC_TYPE(uint8_t, int64_t);
MPS_ACC_TYPE(char, int64_t);
MPS_ACC_TYPE(int16_t, int64_t);
MPS_ACC_TYPE(int32_t, int64_t);
MPS_ACC_TYPE(int64_t, int64_t);
MPS_ACC_TYPE(bool, bool);
MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);

XPU_ACC_TYPE(BFloat16, float);
XPU_ACC_TYPE(Half, float);
XPU_ACC_TYPE(Float8_e5m2, float);
XPU_ACC_TYPE(Float8_e4m3fn, float);
XPU_ACC_TYPE(Float8_e5m2fnuz, float);
XPU_ACC_TYPE(Float8_e4m3fnuz, float);
XPU_ACC_TYPE(float, float);
XPU_ACC_TYPE(double, double);
XPU_ACC_TYPE(int8_t, int64_t);
XPU_ACC_TYPE(uint8_t, int64_t);
XPU_ACC_TYPE(char, int64_t);
XPU_ACC_TYPE(int16_t, int64_t);
XPU_ACC_TYPE(int32_t, int64_t);
XPU_ACC_TYPE(int64_t, int64_t);
XPU_ACC_TYPE(bool, bool);
XPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
XPU_ACC_TYPE(c10::complex<float>, c10::complex<float>);
XPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);

// The raw CUDA `half` type is only defined under device compilers (see the
// guarded includes above).
#if defined(__CUDACC__) || defined(__HIPCC__)
CUDA_ACC_TYPE(half, float);
#endif
CUDA_ACC_TYPE(BFloat16, float);
CUDA_ACC_TYPE(Half, float);
CUDA_ACC_TYPE(Float8_e5m2, float);
CUDA_ACC_TYPE(Float8_e4m3fn, float);
CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
CUDA_ACC_TYPE(float, float);
CUDA_ACC_TYPE(double, double);
CUDA_ACC_TYPE(int8_t, int64_t);
CUDA_ACC_TYPE(uint8_t, int64_t);
CUDA_ACC_TYPE(char, int64_t);
CUDA_ACC_TYPE(int16_t, int64_t);
CUDA_ACC_TYPE(int32_t, int64_t);
CUDA_ACC_TYPE(int64_t, int64_t);
CUDA_ACC_TYPE(bool, bool);
CUDA_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
CUDA_ACC_TYPE(c10::complex<float>, c10::complex<float>);
CUDA_ACC_TYPE(c10::complex<double>, c10::complex<double>);

// On CPU, float (and complex<float>) accumulate in double precision,
// per the rule table in the header comment.
CPU_ACC_TYPE(BFloat16, float);
CPU_ACC_TYPE(Half, float);
CPU_ACC_TYPE(Float8_e5m2, float);
CPU_ACC_TYPE(Float8_e4m3fn, float);
CPU_ACC_TYPE(Float8_e5m2fnuz, float);
CPU_ACC_TYPE(Float8_e4m3fnuz, float);
CPU_ACC_TYPE(float, double);
CPU_ACC_TYPE(double, double);
CPU_ACC_TYPE(int8_t, int64_t);
CPU_ACC_TYPE(uint8_t, int64_t);
CPU_ACC_TYPE(char, int64_t);
CPU_ACC_TYPE(int16_t, int64_t);
CPU_ACC_TYPE(int32_t, int64_t);
CPU_ACC_TYPE(int64_t, int64_t);
CPU_ACC_TYPE(bool, bool);
CPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
CPU_ACC_TYPE(c10::complex<float>, c10::complex<double>);
CPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);

// Runtime (ScalarType-level) counterparts of the compile-time tables above.
TORCH_API c10::ScalarType toAccumulateType(
    c10::ScalarType type,
    c10::DeviceType device);
TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);

} // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/Backend.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
#pragma once
// Forwarding header kept for backward compatibility: the Backend
// definitions now live in c10.
#include <c10/core/Backend.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/CPUApplyUtils.h ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/CollapseDims.h>
4
+ #include <ATen/Parallel.h>
5
+ #include <ATen/TensorUtils.h>
6
+ #include <c10/util/irange.h>
7
+ #include <cstring>
8
+ #include <limits>
9
+
10
+ namespace at {
11
+
12
+ /*
13
+ * The basic strategy for apply is as follows:
14
+ *
15
+ * 1. Starting with the outermost index, loop until we reach a dimension where
16
+ * the data is no longer contiguous, i.e. the stride at that dimension is not
17
+ * equal to the size of the tensor defined by the outer dimensions. Let's call
18
+ * this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
19
+ * A is equal to the entire Tensor. Let's call the inner tensor B.
20
+ *
21
+ * 2. We loop through the indices in B, starting at its outermost dimension. For
22
+ * example, if B is a 2x2 matrix, then we do:
23
+ *
24
+ * B[0][0]
25
+ * B[0][1]
26
+ * B[1][0]
27
+ * B[1][1]
28
+ *
29
+ * We set the offset into the underlying storage as (storageOffset + stride_B *
30
+ * index_B), i.e. basically we compute the offset into the storage as we would
31
+ * normally for a Tensor. But because we are guaranteed the subsequent data is
32
+ * contiguous in memory, we can simply loop for sizeof(A) iterations and perform
33
+ * the operation, without having to follow the order described by the strides of
34
+ * A.
35
+ *
36
+ * 3. As an optimization, we merge dimensions of A that are contiguous in
37
+ * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
38
+ * then the first two dimensions can be merged for the purposes of APPLY,
39
+ * reducing the number of nested loops.
40
+ */
41
+
42
+ inline Tensor sort_strides(Tensor& tensor_) {
43
+ IntArrayRef strides = tensor_.strides();
44
+ std::vector<int64_t> indices;
45
+ indices.reserve(tensor_.ndimension());
46
+ for (const auto i : c10::irange(tensor_.ndimension())) {
47
+ indices.push_back(i);
48
+ }
49
+ std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
50
+ return strides[i1] > strides[i2];
51
+ });
52
+ Tensor tensor = tensor_.permute(indices);
53
+ return tensor;
54
+ }
55
+
56
// Fixed-capacity iterator state over a strided tensor with at most N
// dimensions. counter_ holds the current multi-dimensional index; dim_ is
// the dimension count after contiguous dimensions are merged.
// NOTE(review): assumes tensor.dim() <= N; callers in this file gate on
// _max_dim_tensors(...) <= 8 before instantiating <T, 8>.
template <typename T, int N>
struct strided_tensor_iter_fixed {
 public:
  T* data_ = NULL;
  int64_t dim_ = 0;

  int64_t counter_[N] = {0};
  int64_t sizes_[N] = {0};
  int64_t strides_[N] = {0};

  // Non-copyable (holds raw position state), movable.
  strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
  void operator=(strided_tensor_iter_fixed const& x) = delete;
  strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
  strided_tensor_iter_fixed(
      Tensor& tensor,
      C10_UNUSED bool sort_strides = false)
      : data_(tensor.data_ptr<T>()) {
    std::memset(counter_, 0, sizeof(int64_t) * N);
    if (tensor.dim() > 0) {
      std::memcpy(
          sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
      std::memcpy(
          strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
    }
    // Merge dimensions that are contiguous in memory (step 3 of the
    // strategy comment above); dim_ becomes the collapsed count.
    dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
  }
};
83
+
84
// Heap-backed iterator state for tensors whose dimensionality exceeds what
// strided_tensor_iter_fixed supports; otherwise identical in role.
template <typename T>
struct strided_tensor_iter {
 private:
 public:
  T* data_ = NULL;
  int64_t dim_;

  std::vector<int64_t> counter_;
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;

  // Non-copyable (holds raw position state), movable.
  strided_tensor_iter(strided_tensor_iter const&) = delete;
  void operator=(strided_tensor_iter const& x) = delete;
  strided_tensor_iter(strided_tensor_iter&&) = default;
  strided_tensor_iter(Tensor& tensor)
      : data_(tensor.data_ptr<T>()),
        dim_(tensor.ndimension()),
        counter_(dim_, 0),
        sizes_(tensor.sizes().vec()),
        strides_(tensor.strides().vec()) {
    // Merge contiguous dimensions; dim_ shrinks to the collapsed count.
    dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
  }
};
107
+
108
+ inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
109
+ if (tensors.empty())
110
+ return true;
111
+ int64_t all_numel = tensors[0].numel();
112
+ for (const auto i : c10::irange(1, tensors.size())) {
113
+ if (tensors[i].numel() != all_numel)
114
+ return false;
115
+ }
116
+ return true;
117
+ }
118
+
119
+ inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
120
+ std::ostringstream oss;
121
+ oss << "inconsistent tensor size, expected ";
122
+ for (size_t i = 0; i < tensors.size() - 1; i++) {
123
+ oss << tensors[i].sizes() << ", ";
124
+ }
125
+ oss << "and " << tensors[tensors.size() - 1].sizes()
126
+ << " to have the same number of elements, but got ";
127
+ for (size_t i = 0; i < tensors.size() - 1; i++) {
128
+ oss << tensors[i].numel() << ", ";
129
+ }
130
+ oss << "and " << tensors[tensors.size() - 1].numel()
131
+ << " elements respectively";
132
+ return oss.str();
133
+ }
134
+
135
// Common validation for CPU_tensor_apply*: all tensors must be CPU,
// strided, and agree on numel (throws otherwise). Returns false when the
// tensors are empty, signalling the caller to skip the apply entirely.
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
  checkDeviceType("CPU_tensor_apply", tensors, kCPU);
  checkLayout("CPU_tensor_apply", tensors, kStrided);
  if (!_all_equal_numel(tensors))
    AT_ERROR(_all_equal_numel_error(tensors));
  // An empty tensor has no elements
  for (auto& t : tensors)
    if (t.numel() == 0)
      return false;
  return true;
}
146
+
147
+ inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
148
+ int64_t dim = 0;
149
+ for (auto& t : tensors)
150
+ dim = std::max(dim, t.ndimension());
151
+ return dim;
152
+ }
153
+
154
// Variadic helpers driving apply_op below. Each helper peels one iterator
// off the parameter pack, operates on it, and recurses on the tail; the
// zero-argument overloads are the recursion base cases.
inline void iterate(int64_t /*size*/){};

// Advances every iterator by `size` positions along its innermost
// (collapsed) dimension, updating both the counter and the data pointer.
template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  iter.counter_[iter.dim_ - 1] += size;
  iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
  iterate(size, iter_tail...);
}

inline bool iterate_continue() {
  return true;
};

// True while every iterator still has elements left in its innermost
// dimension (i.e. no counter has reached its size).
template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
      iterate_continue(iter_tail...);
}

inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
};

// Largest step all iterators can take before one of them exhausts its
// innermost dimension.
template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  return std::min(
      (iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
      max_iterate_size(iter_tail...));
}

inline void iterate_overflow(){};

// Carry propagation: when a counter reaches its dimension's size, reset it,
// bump the next-outer counter, and rewind/advance data_ accordingly.
template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
    for (int64_t i = iter.dim_ - 1; i > 0; i--) {
      if (iter.counter_[i] == iter.sizes_[i]) {
        iter.counter_[i] = 0;
        iter.counter_[i - 1]++;
        iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
            iter.strides_[i - 1];
      }
    }
  }
  iterate_overflow(iter_tail...);
}

inline void forward(int64_t /*offset*/){};

// Positions every iterator at linear index `offset` by decomposing the
// offset mixed-radix over the iterator's sizes (innermost digit first).
template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t multi = offset;
  for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
    int64_t inc = multi % iter.sizes_[i];
    multi = multi / iter.sizes_[i];
    iter.data_ = iter.data_ + inc * iter.strides_[i];
    iter.counter_[i] += inc;
  }
  forward(offset, iter_tail...);
}

inline int64_t max_dim() {
  return 0;
}

// Maximum collapsed dimensionality across all iterators.
template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  return std::max(iter.dim_, max_dim(iter_tail...));
}

inline void apply_op(){};

// Applies `op` element-wise across `numel` elements of every iterator,
// starting from linear index `offset`. The inner loop steps one element at
// a time; the outer loop performs counter overflow handling between chunks.
template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  // For 0-dim tensors
  if (numel == 1 && max_dim(iters...) == 0) {
    op(*iters.data_...);
    return;
  }
  if (offset > 0)
    forward(offset, iters...);
  // Splitting this into chunks helps the compiler create faster assembly
  for (int64_t i = 0; i < numel;) {
    for (; iterate_continue(iters...) && i < numel;) {
      op(*iters.data_...);
      iterate(1, iters...);
      i++;
    }
    iterate_overflow(iters...);
  }
}
249
+
250
+ /*
251
+ Apply a pointwise operator to sequence of tensors
252
+
253
+ The calling convention for op is a function/functor that takes the same
254
+ number of pointers of type scalar as the number of given tensors. For example,
255
+ to compute a = b * c, op would be of the form:
256
+ [](scalar* a_val, const scalar* b_val, const scalar* c_val) { a_val[0] =
257
+ b_val[0] * c_val[0]; };
258
+ */
259
+
260
+ template <typename scalar1, typename scalar2, typename Op>
261
+ inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
262
+ if (!_apply_preamble({tensor1, tensor2}))
263
+ return;
264
+ if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
265
+ apply_op(
266
+ tensor1.numel(),
267
+ 0,
268
+ op,
269
+ strided_tensor_iter_fixed<scalar1, 8>(tensor1),
270
+ strided_tensor_iter_fixed<scalar2, 8>(tensor2));
271
+ } else {
272
+ apply_op(
273
+ tensor1.numel(),
274
+ 0,
275
+ op,
276
+ strided_tensor_iter<scalar1>(tensor1),
277
+ strided_tensor_iter<scalar2>(tensor2));
278
+ }
279
+ }
280
+
281
+ template <typename scalar1, typename scalar2, typename scalar3, typename Op>
282
+ inline void CPU_tensor_apply3(
283
+ Tensor tensor1,
284
+ Tensor tensor2,
285
+ Tensor tensor3,
286
+ const Op op) {
287
+ if (!_apply_preamble({tensor1, tensor2, tensor3}))
288
+ return;
289
+ if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
290
+ apply_op(
291
+ tensor1.numel(),
292
+ 0,
293
+ op,
294
+ strided_tensor_iter_fixed<scalar1, 8>(tensor1),
295
+ strided_tensor_iter_fixed<scalar2, 8>(tensor2),
296
+ strided_tensor_iter_fixed<scalar3, 8>(tensor3));
297
+ } else {
298
+ apply_op(
299
+ tensor1.numel(),
300
+ 0,
301
+ op,
302
+ strided_tensor_iter<scalar1>(tensor1),
303
+ strided_tensor_iter<scalar2>(tensor2),
304
+ strided_tensor_iter<scalar3>(tensor3));
305
+ }
306
+ }
307
+
308
+ template <
309
+ typename scalar1,
310
+ typename scalar2,
311
+ typename scalar3,
312
+ typename scalar4,
313
+ typename Op>
314
+ inline void CPU_tensor_apply4(
315
+ Tensor tensor1,
316
+ Tensor tensor2,
317
+ Tensor tensor3,
318
+ Tensor tensor4,
319
+ const Op op) {
320
+ if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
321
+ return;
322
+ if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
323
+ apply_op(
324
+ tensor1.numel(),
325
+ 0,
326
+ op,
327
+ strided_tensor_iter_fixed<scalar1, 8>(tensor1),
328
+ strided_tensor_iter_fixed<scalar2, 8>(tensor2),
329
+ strided_tensor_iter_fixed<scalar3, 8>(tensor3),
330
+ strided_tensor_iter_fixed<scalar4, 8>(tensor4));
331
+ } else {
332
+ apply_op(
333
+ tensor1.numel(),
334
+ 0,
335
+ op,
336
+ strided_tensor_iter<scalar1>(tensor1),
337
+ strided_tensor_iter<scalar2>(tensor2),
338
+ strided_tensor_iter<scalar3>(tensor3),
339
+ strided_tensor_iter<scalar4>(tensor4));
340
+ }
341
+ }
342
+
343
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/CPUFixedAllocator.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/Allocator.h>
4
+ #include <c10/util/Exception.h>
5
+
6
+ // This file creates a fake allocator that just throws exceptions if
7
+ // it is actually used.
8
+
9
+ // state passed to the allocator is the std::function<void(void*)> called
10
+ // when the blob is release by ATen
11
+
12
+ namespace at {
13
+
14
+ static cpu_fixed_malloc(void*, ptrdiff_t) {
15
+ AT_ERROR("attempting to resize a tensor view of an external blob");
16
+ }
17
+
18
+ static cpu_fixed_realloc(void*, void*, ptrdiff_t) {
19
+ AT_ERROR("attempting to resize a tensor view of an external blob");
20
+ }
21
+
22
+ static cpu_fixed_free(void* state, void* allocation) {
23
+ auto on_release = static_cast<std::function<void(void*)>*>(state);
24
+ (*on_release)(allocation);
25
+ delete on_release;
26
+ }
27
+
28
+ static Allocator CPU_fixed_allocator = {
29
+ cpu_fixed_malloc,
30
+ cpu_fixed_realloc,
31
+ cpu_fixed_free};
32
+
33
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/CPUFunctions.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#include <ATen/core/TensorBody.h>

// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
// Code introduced to avoid cyclic dependency in static dispatch is no longer
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
// to Operators.cpp for supporting multiple backends with multiple kernels.
//
// Note [Avoiding Include Cycles In Static Dispatch]
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
//
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
// directly inlined into TensorBody.h.
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
// which include functions that have defaultable std::optional<Tensor> arguments.
// That requires knowing the full Tensor class definition.
//
// We break the cycle by doing the following:
// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
// - CPUFunctions_inl.h includes everything else
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
// and then it includes CPUFunctions_inl.h.
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
// - This also means that in the static dispatch build, CPUFunctions.h only needs to
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
#include <ATen/CPUFunctions_inl.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/CPUGeneratorImpl.h ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once

#include <ATen/core/Generator.h>
#include <ATen/core/MT19937RNGEngine.h>
#include <c10/core/GeneratorImpl.h>
#include <optional>

namespace at {

// CPU random number generator backed by an at::mt19937 (Mersenne Twister)
// engine. Declarations only; implementations live in the corresponding .cpp.
struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
  // Constructors
  CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
  ~CPUGeneratorImpl() override = default;

  // CPUGeneratorImpl methods
  std::shared_ptr<CPUGeneratorImpl> clone() const;
  void set_current_seed(uint64_t seed) override;
  void set_offset(uint64_t offset) override;
  uint64_t get_offset() const override;
  uint64_t current_seed() const override;
  uint64_t seed() override;
  // (De)serialize the generator state through a TensorImpl.
  void set_state(const c10::TensorImpl& new_state) override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
  static c10::DeviceType device_type();
  // Draw the next 32-/64-bit value from the engine.
  uint32_t random();
  uint64_t random64();
  // Cached "next" normal samples, per precision; presumably the spare value
  // from pairwise normal generation — nullopt when none is cached.
  // TODO(review): confirm against the .cpp implementation.
  std::optional<float> next_float_normal_sample();
  std::optional<double> next_double_normal_sample();
  void set_next_float_normal_sample(std::optional<float> randn);
  void set_next_double_normal_sample(std::optional<double> randn);
  at::mt19937 engine();
  void set_engine(at::mt19937 engine);

 private:
  CPUGeneratorImpl* clone_impl() const override;
  at::mt19937 engine_;
  std::optional<float> next_float_normal_sample_;
  std::optional<double> next_double_normal_sample_;
};

namespace detail {

// Accessor for the process-wide default CPU generator.
TORCH_API const Generator& getDefaultCPUGenerator();
// Creates a new Generator wrapping a fresh CPUGeneratorImpl.
TORCH_API Generator
createCPUGenerator(uint64_t seed_val = default_rng_seed_val);

} // namespace detail

} // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
12
+ #error This change adds a dependency on all pytorch operators, meaning the \
13
+ file will need to be re-compiled every time an operator is changed or added. \
14
+ Consider including a specific operator from \
15
+ <ATen/ops/{my_operator}_compositeexplicitautograd_dispatch.h>. \
16
+ See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
17
+ #endif
18
+
19
+ #include <ATen/ops/_adaptive_avg_pool2d_compositeexplicitautograd_dispatch.h>
20
+ #include <ATen/ops/_adaptive_avg_pool2d_backward_compositeexplicitautograd_dispatch.h>
21
+ #include <ATen/ops/_adaptive_avg_pool3d_compositeexplicitautograd_dispatch.h>
22
+ #include <ATen/ops/_adaptive_avg_pool3d_backward_compositeexplicitautograd_dispatch.h>
23
+ #include <ATen/ops/_add_relu_compositeexplicitautograd_dispatch.h>
24
+ #include <ATen/ops/_aminmax_compositeexplicitautograd_dispatch.h>
25
+ #include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_compositeexplicitautograd_dispatch.h>
26
+ #include <ATen/ops/_amp_update_scale_compositeexplicitautograd_dispatch.h>
27
+ #include <ATen/ops/_assert_scalar_compositeexplicitautograd_dispatch.h>
28
+ #include <ATen/ops/_batch_norm_no_update_compositeexplicitautograd_dispatch.h>
29
+ #include <ATen/ops/_batch_norm_with_update_compositeexplicitautograd_dispatch.h>
30
+ #include <ATen/ops/_cdist_backward_compositeexplicitautograd_dispatch.h>
31
+ #include <ATen/ops/_cdist_forward_compositeexplicitautograd_dispatch.h>
32
+ #include <ATen/ops/_cholesky_solve_helper_compositeexplicitautograd_dispatch.h>
33
+ #include <ATen/ops/_chunk_cat_compositeexplicitautograd_dispatch.h>
34
+ #include <ATen/ops/_coalesce_compositeexplicitautograd_dispatch.h>
35
+ #include <ATen/ops/_coalesced_compositeexplicitautograd_dispatch.h>
36
+ #include <ATen/ops/_conj_compositeexplicitautograd_dispatch.h>
37
+ #include <ATen/ops/_conj_copy_compositeexplicitautograd_dispatch.h>
38
+ #include <ATen/ops/_conj_physical_compositeexplicitautograd_dispatch.h>
39
+ #include <ATen/ops/_convolution_compositeexplicitautograd_dispatch.h>
40
+ #include <ATen/ops/_copy_from_compositeexplicitautograd_dispatch.h>
41
+ #include <ATen/ops/_copy_from_and_resize_compositeexplicitautograd_dispatch.h>
42
+ #include <ATen/ops/_ctc_loss_compositeexplicitautograd_dispatch.h>
43
+ #include <ATen/ops/_ctc_loss_backward_compositeexplicitautograd_dispatch.h>
44
+ #include <ATen/ops/_cudnn_ctc_loss_compositeexplicitautograd_dispatch.h>
45
+ #include <ATen/ops/_cudnn_init_dropout_state_compositeexplicitautograd_dispatch.h>
46
+ #include <ATen/ops/_cudnn_rnn_compositeexplicitautograd_dispatch.h>
47
+ #include <ATen/ops/_cudnn_rnn_backward_compositeexplicitautograd_dispatch.h>
48
+ #include <ATen/ops/_cudnn_rnn_flatten_weight_compositeexplicitautograd_dispatch.h>
49
+ #include <ATen/ops/_dirichlet_grad_compositeexplicitautograd_dispatch.h>
50
+ #include <ATen/ops/_efficientzerotensor_compositeexplicitautograd_dispatch.h>
51
+ #include <ATen/ops/_embedding_bag_compositeexplicitautograd_dispatch.h>
52
+ #include <ATen/ops/_embedding_bag_dense_backward_compositeexplicitautograd_dispatch.h>
53
+ #include <ATen/ops/_embedding_bag_forward_only_compositeexplicitautograd_dispatch.h>
54
+ #include <ATen/ops/_embedding_bag_per_sample_weights_backward_compositeexplicitautograd_dispatch.h>
55
+ #include <ATen/ops/_empty_affine_quantized_compositeexplicitautograd_dispatch.h>
56
+ #include <ATen/ops/_empty_per_channel_affine_quantized_compositeexplicitautograd_dispatch.h>
57
+ #include <ATen/ops/_euclidean_dist_compositeexplicitautograd_dispatch.h>
58
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine_compositeexplicitautograd_dispatch.h>
59
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_compositeexplicitautograd_dispatch.h>
60
+ #include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_compositeexplicitautograd_dispatch.h>
61
+ #include <ATen/ops/_foobar_compositeexplicitautograd_dispatch.h>
62
+ #include <ATen/ops/_foreach_abs_compositeexplicitautograd_dispatch.h>
63
+ #include <ATen/ops/_foreach_acos_compositeexplicitautograd_dispatch.h>
64
+ #include <ATen/ops/_foreach_add_compositeexplicitautograd_dispatch.h>
65
+ #include <ATen/ops/_foreach_addcdiv_compositeexplicitautograd_dispatch.h>
66
+ #include <ATen/ops/_foreach_addcmul_compositeexplicitautograd_dispatch.h>
67
+ #include <ATen/ops/_foreach_asin_compositeexplicitautograd_dispatch.h>
68
+ #include <ATen/ops/_foreach_atan_compositeexplicitautograd_dispatch.h>
69
+ #include <ATen/ops/_foreach_ceil_compositeexplicitautograd_dispatch.h>
70
+ #include <ATen/ops/_foreach_clamp_max_compositeexplicitautograd_dispatch.h>
71
+ #include <ATen/ops/_foreach_clamp_min_compositeexplicitautograd_dispatch.h>
72
+ #include <ATen/ops/_foreach_copy_compositeexplicitautograd_dispatch.h>
73
+ #include <ATen/ops/_foreach_cos_compositeexplicitautograd_dispatch.h>
74
+ #include <ATen/ops/_foreach_cosh_compositeexplicitautograd_dispatch.h>
75
+ #include <ATen/ops/_foreach_div_compositeexplicitautograd_dispatch.h>
76
+ #include <ATen/ops/_foreach_erf_compositeexplicitautograd_dispatch.h>
77
+ #include <ATen/ops/_foreach_erfc_compositeexplicitautograd_dispatch.h>
78
+ #include <ATen/ops/_foreach_exp_compositeexplicitautograd_dispatch.h>
79
+ #include <ATen/ops/_foreach_expm1_compositeexplicitautograd_dispatch.h>
80
+ #include <ATen/ops/_foreach_floor_compositeexplicitautograd_dispatch.h>
81
+ #include <ATen/ops/_foreach_frac_compositeexplicitautograd_dispatch.h>
82
+ #include <ATen/ops/_foreach_lerp_compositeexplicitautograd_dispatch.h>
83
+ #include <ATen/ops/_foreach_lgamma_compositeexplicitautograd_dispatch.h>
84
+ #include <ATen/ops/_foreach_log_compositeexplicitautograd_dispatch.h>
85
+ #include <ATen/ops/_foreach_log10_compositeexplicitautograd_dispatch.h>
86
+ #include <ATen/ops/_foreach_log1p_compositeexplicitautograd_dispatch.h>
87
+ #include <ATen/ops/_foreach_log2_compositeexplicitautograd_dispatch.h>
88
+ #include <ATen/ops/_foreach_max_compositeexplicitautograd_dispatch.h>
89
+ #include <ATen/ops/_foreach_maximum_compositeexplicitautograd_dispatch.h>
90
+ #include <ATen/ops/_foreach_minimum_compositeexplicitautograd_dispatch.h>
91
+ #include <ATen/ops/_foreach_mul_compositeexplicitautograd_dispatch.h>
92
+ #include <ATen/ops/_foreach_neg_compositeexplicitautograd_dispatch.h>
93
+ #include <ATen/ops/_foreach_norm_compositeexplicitautograd_dispatch.h>
94
+ #include <ATen/ops/_foreach_pow_compositeexplicitautograd_dispatch.h>
95
+ #include <ATen/ops/_foreach_reciprocal_compositeexplicitautograd_dispatch.h>
96
+ #include <ATen/ops/_foreach_round_compositeexplicitautograd_dispatch.h>
97
+ #include <ATen/ops/_foreach_sigmoid_compositeexplicitautograd_dispatch.h>
98
+ #include <ATen/ops/_foreach_sign_compositeexplicitautograd_dispatch.h>
99
+ #include <ATen/ops/_foreach_sin_compositeexplicitautograd_dispatch.h>
100
+ #include <ATen/ops/_foreach_sinh_compositeexplicitautograd_dispatch.h>
101
+ #include <ATen/ops/_foreach_sqrt_compositeexplicitautograd_dispatch.h>
102
+ #include <ATen/ops/_foreach_sub_compositeexplicitautograd_dispatch.h>
103
+ #include <ATen/ops/_foreach_tan_compositeexplicitautograd_dispatch.h>
104
+ #include <ATen/ops/_foreach_tanh_compositeexplicitautograd_dispatch.h>
105
+ #include <ATen/ops/_foreach_trunc_compositeexplicitautograd_dispatch.h>
106
+ #include <ATen/ops/_foreach_zero_compositeexplicitautograd_dispatch.h>
107
+ #include <ATen/ops/_functional_assert_scalar_compositeexplicitautograd_dispatch.h>
108
+ #include <ATen/ops/_functional_sym_constrain_range_compositeexplicitautograd_dispatch.h>
109
+ #include <ATen/ops/_functional_sym_constrain_range_for_size_compositeexplicitautograd_dispatch.h>
110
+ #include <ATen/ops/_fused_adagrad_compositeexplicitautograd_dispatch.h>
111
+ #include <ATen/ops/_fused_adam_compositeexplicitautograd_dispatch.h>
112
+ #include <ATen/ops/_fused_adamw_compositeexplicitautograd_dispatch.h>
113
+ #include <ATen/ops/_fused_dropout_compositeexplicitautograd_dispatch.h>
114
+ #include <ATen/ops/_fused_moving_avg_obs_fq_helper_compositeexplicitautograd_dispatch.h>
115
+ #include <ATen/ops/_fused_sgd_compositeexplicitautograd_dispatch.h>
116
+ #include <ATen/ops/_fw_primal_compositeexplicitautograd_dispatch.h>
117
+ #include <ATen/ops/_fw_primal_copy_compositeexplicitautograd_dispatch.h>
118
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback_compositeexplicitautograd_dispatch.h>
119
+ #include <ATen/ops/_has_same_storage_numel_compositeexplicitautograd_dispatch.h>
120
+ #include <ATen/ops/_histogramdd_bin_edges_compositeexplicitautograd_dispatch.h>
121
+ #include <ATen/ops/_histogramdd_from_bin_cts_compositeexplicitautograd_dispatch.h>
122
+ #include <ATen/ops/_histogramdd_from_bin_tensors_compositeexplicitautograd_dispatch.h>
123
+ #include <ATen/ops/_index_put_impl_compositeexplicitautograd_dispatch.h>
124
+ #include <ATen/ops/_indices_copy_compositeexplicitautograd_dispatch.h>
125
+ #include <ATen/ops/_is_all_true_compositeexplicitautograd_dispatch.h>
126
+ #include <ATen/ops/_is_any_true_compositeexplicitautograd_dispatch.h>
127
+ #include <ATen/ops/_lazy_clone_compositeexplicitautograd_dispatch.h>
128
+ #include <ATen/ops/_linalg_check_errors_compositeexplicitautograd_dispatch.h>
129
+ #include <ATen/ops/_lstm_mps_compositeexplicitautograd_dispatch.h>
130
+ #include <ATen/ops/_make_dual_compositeexplicitautograd_dispatch.h>
131
+ #include <ATen/ops/_make_dual_copy_compositeexplicitautograd_dispatch.h>
132
+ #include <ATen/ops/_make_per_channel_quantized_tensor_compositeexplicitautograd_dispatch.h>
133
+ #include <ATen/ops/_make_per_tensor_quantized_tensor_compositeexplicitautograd_dispatch.h>
134
+ #include <ATen/ops/_masked_scale_compositeexplicitautograd_dispatch.h>
135
+ #include <ATen/ops/_masked_softmax_compositeexplicitautograd_dispatch.h>
136
+ #include <ATen/ops/_masked_softmax_backward_compositeexplicitautograd_dispatch.h>
137
+ #include <ATen/ops/_mkldnn_reshape_compositeexplicitautograd_dispatch.h>
138
+ #include <ATen/ops/_mkldnn_transpose_compositeexplicitautograd_dispatch.h>
139
+ #include <ATen/ops/_mps_convolution_compositeexplicitautograd_dispatch.h>
140
+ #include <ATen/ops/_mps_convolution_transpose_compositeexplicitautograd_dispatch.h>
141
+ #include <ATen/ops/_native_batch_norm_legit_compositeexplicitautograd_dispatch.h>
142
+ #include <ATen/ops/_native_batch_norm_legit_no_training_compositeexplicitautograd_dispatch.h>
143
+ #include <ATen/ops/_native_multi_head_attention_compositeexplicitautograd_dispatch.h>
144
+ #include <ATen/ops/_neg_view_compositeexplicitautograd_dispatch.h>
145
+ #include <ATen/ops/_neg_view_copy_compositeexplicitautograd_dispatch.h>
146
+ #include <ATen/ops/_nested_from_padded_compositeexplicitautograd_dispatch.h>
147
+ #include <ATen/ops/_nested_from_padded_and_nested_example_compositeexplicitautograd_dispatch.h>
148
+ #include <ATen/ops/_nested_get_values_copy_compositeexplicitautograd_dispatch.h>
149
+ #include <ATen/ops/_nested_tensor_from_mask_compositeexplicitautograd_dispatch.h>
150
+ #include <ATen/ops/_nested_tensor_from_tensor_list_compositeexplicitautograd_dispatch.h>
151
+ #include <ATen/ops/_nested_tensor_size_compositeexplicitautograd_dispatch.h>
152
+ #include <ATen/ops/_nested_tensor_storage_offsets_compositeexplicitautograd_dispatch.h>
153
+ #include <ATen/ops/_nested_tensor_strides_compositeexplicitautograd_dispatch.h>
154
+ #include <ATen/ops/_nested_view_from_buffer_copy_compositeexplicitautograd_dispatch.h>
155
+ #include <ATen/ops/_nested_view_from_jagged_copy_compositeexplicitautograd_dispatch.h>
156
+ #include <ATen/ops/_new_zeros_with_same_feature_meta_compositeexplicitautograd_dispatch.h>
157
+ #include <ATen/ops/_nnpack_spatial_convolution_compositeexplicitautograd_dispatch.h>
158
+ #include <ATen/ops/_pack_padded_sequence_compositeexplicitautograd_dispatch.h>
159
+ #include <ATen/ops/_pdist_backward_compositeexplicitautograd_dispatch.h>
160
+ #include <ATen/ops/_pdist_forward_compositeexplicitautograd_dispatch.h>
161
+ #include <ATen/ops/_pin_memory_compositeexplicitautograd_dispatch.h>
162
+ #include <ATen/ops/_print_compositeexplicitautograd_dispatch.h>
163
+ #include <ATen/ops/_reshape_alias_copy_compositeexplicitautograd_dispatch.h>
164
+ #include <ATen/ops/_reshape_copy_compositeexplicitautograd_dispatch.h>
165
+ #include <ATen/ops/_resize_output_compositeexplicitautograd_dispatch.h>
166
+ #include <ATen/ops/_safe_softmax_compositeexplicitautograd_dispatch.h>
167
+ #include <ATen/ops/_sample_dirichlet_compositeexplicitautograd_dispatch.h>
168
+ #include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_compositeexplicitautograd_dispatch.h>
169
+ #include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_compositeexplicitautograd_dispatch.h>
170
+ #include <ATen/ops/_segment_reduce_backward_compositeexplicitautograd_dispatch.h>
171
+ #include <ATen/ops/_slow_conv2d_backward_compositeexplicitautograd_dispatch.h>
172
+ #include <ATen/ops/_sparse_addmm_compositeexplicitautograd_dispatch.h>
173
+ #include <ATen/ops/_sparse_broadcast_to_copy_compositeexplicitautograd_dispatch.h>
174
+ #include <ATen/ops/_sparse_compressed_tensor_with_dims_compositeexplicitautograd_dispatch.h>
175
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_compositeexplicitautograd_dispatch.h>
176
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_compositeexplicitautograd_dispatch.h>
177
+ #include <ATen/ops/_sparse_csr_prod_compositeexplicitautograd_dispatch.h>
178
+ #include <ATen/ops/_sparse_csr_sum_compositeexplicitautograd_dispatch.h>
179
+ #include <ATen/ops/_sparse_log_softmax_compositeexplicitautograd_dispatch.h>
180
+ #include <ATen/ops/_sparse_log_softmax_backward_data_compositeexplicitautograd_dispatch.h>
181
+ #include <ATen/ops/_sparse_mask_projection_compositeexplicitautograd_dispatch.h>
182
+ #include <ATen/ops/_sparse_softmax_compositeexplicitautograd_dispatch.h>
183
+ #include <ATen/ops/_sparse_softmax_backward_data_compositeexplicitautograd_dispatch.h>
184
+ #include <ATen/ops/_sparse_sparse_matmul_compositeexplicitautograd_dispatch.h>
185
+ #include <ATen/ops/_sparse_sum_compositeexplicitautograd_dispatch.h>
186
+ #include <ATen/ops/_sparse_sum_backward_compositeexplicitautograd_dispatch.h>
187
+ #include <ATen/ops/_spdiags_compositeexplicitautograd_dispatch.h>
188
+ #include <ATen/ops/_stack_compositeexplicitautograd_dispatch.h>
189
+ #include <ATen/ops/_standard_gamma_compositeexplicitautograd_dispatch.h>
190
+ #include <ATen/ops/_standard_gamma_grad_compositeexplicitautograd_dispatch.h>
191
+ #include <ATen/ops/_test_autograd_multiple_dispatch_compositeexplicitautograd_dispatch.h>
192
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_compositeexplicitautograd_dispatch.h>
193
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_compositeexplicitautograd_dispatch.h>
194
+ #include <ATen/ops/_test_functorch_fallback_compositeexplicitautograd_dispatch.h>
195
+ #include <ATen/ops/_test_optional_filled_intlist_compositeexplicitautograd_dispatch.h>
196
+ #include <ATen/ops/_test_optional_floatlist_compositeexplicitautograd_dispatch.h>
197
+ #include <ATen/ops/_test_optional_intlist_compositeexplicitautograd_dispatch.h>
198
+ #include <ATen/ops/_test_parallel_materialize_compositeexplicitautograd_dispatch.h>
199
+ #include <ATen/ops/_test_warn_in_autograd_compositeexplicitautograd_dispatch.h>
200
+ #include <ATen/ops/_thnn_fused_gru_cell_compositeexplicitautograd_dispatch.h>
201
+ #include <ATen/ops/_thnn_fused_gru_cell_backward_compositeexplicitautograd_dispatch.h>
202
+ #include <ATen/ops/_thnn_fused_lstm_cell_compositeexplicitautograd_dispatch.h>
203
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_compositeexplicitautograd_dispatch.h>
204
+ #include <ATen/ops/_to_copy_compositeexplicitautograd_dispatch.h>
205
+ #include <ATen/ops/_to_dense_compositeexplicitautograd_dispatch.h>
206
+ #include <ATen/ops/_to_sparse_compositeexplicitautograd_dispatch.h>
207
+ #include <ATen/ops/_to_sparse_bsc_compositeexplicitautograd_dispatch.h>
208
+ #include <ATen/ops/_to_sparse_bsr_compositeexplicitautograd_dispatch.h>
209
+ #include <ATen/ops/_to_sparse_csc_compositeexplicitautograd_dispatch.h>
210
+ #include <ATen/ops/_to_sparse_csr_compositeexplicitautograd_dispatch.h>
211
+ #include <ATen/ops/_transform_bias_rescale_qkv_compositeexplicitautograd_dispatch.h>
212
+ #include <ATen/ops/_transformer_encoder_layer_fwd_compositeexplicitautograd_dispatch.h>
213
+ #include <ATen/ops/_trilinear_compositeexplicitautograd_dispatch.h>
214
+ #include <ATen/ops/_triton_multi_head_attention_compositeexplicitautograd_dispatch.h>
215
+ #include <ATen/ops/_triton_scaled_dot_attention_compositeexplicitautograd_dispatch.h>
216
+ #include <ATen/ops/_unique_compositeexplicitautograd_dispatch.h>
217
+ #include <ATen/ops/_unique2_compositeexplicitautograd_dispatch.h>
218
+ #include <ATen/ops/_unsafe_index_compositeexplicitautograd_dispatch.h>
219
+ #include <ATen/ops/_unsafe_index_put_compositeexplicitautograd_dispatch.h>
220
+ #include <ATen/ops/_unsafe_masked_index_compositeexplicitautograd_dispatch.h>
221
+ #include <ATen/ops/_unsafe_masked_index_put_accumulate_compositeexplicitautograd_dispatch.h>
222
+ #include <ATen/ops/_unsafe_view_compositeexplicitautograd_dispatch.h>
223
+ #include <ATen/ops/_values_copy_compositeexplicitautograd_dispatch.h>
224
+ #include <ATen/ops/_weight_norm_interface_compositeexplicitautograd_dispatch.h>
225
+ #include <ATen/ops/_weight_norm_interface_backward_compositeexplicitautograd_dispatch.h>
226
+ #include <ATen/ops/abs_compositeexplicitautograd_dispatch.h>
227
+ #include <ATen/ops/add_compositeexplicitautograd_dispatch.h>
228
+ #include <ATen/ops/addr_compositeexplicitautograd_dispatch.h>
229
+ #include <ATen/ops/affine_grid_generator_compositeexplicitautograd_dispatch.h>
230
+ #include <ATen/ops/alias_compositeexplicitautograd_dispatch.h>
231
+ #include <ATen/ops/alias_copy_compositeexplicitautograd_dispatch.h>
232
+ #include <ATen/ops/all_compositeexplicitautograd_dispatch.h>
233
+ #include <ATen/ops/allclose_compositeexplicitautograd_dispatch.h>
234
+ #include <ATen/ops/any_compositeexplicitautograd_dispatch.h>
235
+ #include <ATen/ops/arange_compositeexplicitautograd_dispatch.h>
236
+ #include <ATen/ops/as_strided_copy_compositeexplicitautograd_dispatch.h>
237
+ #include <ATen/ops/as_strided_scatter_compositeexplicitautograd_dispatch.h>
238
+ #include <ATen/ops/bartlett_window_compositeexplicitautograd_dispatch.h>
239
+ #include <ATen/ops/batch_norm_backward_elemt_compositeexplicitautograd_dispatch.h>
240
+ #include <ATen/ops/batch_norm_backward_reduce_compositeexplicitautograd_dispatch.h>
241
+ #include <ATen/ops/batch_norm_gather_stats_compositeexplicitautograd_dispatch.h>
242
+ #include <ATen/ops/batch_norm_gather_stats_with_counts_compositeexplicitautograd_dispatch.h>
243
+ #include <ATen/ops/batch_norm_stats_compositeexplicitautograd_dispatch.h>
244
+ #include <ATen/ops/batch_norm_update_stats_compositeexplicitautograd_dispatch.h>
245
+ #include <ATen/ops/bernoulli_compositeexplicitautograd_dispatch.h>
246
+ #include <ATen/ops/binary_cross_entropy_with_logits_compositeexplicitautograd_dispatch.h>
247
+ #include <ATen/ops/bincount_compositeexplicitautograd_dispatch.h>
248
+ #include <ATen/ops/binomial_compositeexplicitautograd_dispatch.h>
249
+ #include <ATen/ops/bitwise_and_compositeexplicitautograd_dispatch.h>
250
+ #include <ATen/ops/bitwise_left_shift_compositeexplicitautograd_dispatch.h>
251
+ #include <ATen/ops/bitwise_or_compositeexplicitautograd_dispatch.h>
252
+ #include <ATen/ops/bitwise_right_shift_compositeexplicitautograd_dispatch.h>
253
+ #include <ATen/ops/bitwise_xor_compositeexplicitautograd_dispatch.h>
254
+ #include <ATen/ops/blackman_window_compositeexplicitautograd_dispatch.h>
255
+ #include <ATen/ops/block_diag_compositeexplicitautograd_dispatch.h>
256
+ #include <ATen/ops/bucketize_compositeexplicitautograd_dispatch.h>
257
+ #include <ATen/ops/cauchy_compositeexplicitautograd_dispatch.h>
258
+ #include <ATen/ops/ccol_indices_compositeexplicitautograd_dispatch.h>
259
+ #include <ATen/ops/ccol_indices_copy_compositeexplicitautograd_dispatch.h>
260
+ #include <ATen/ops/celu_compositeexplicitautograd_dispatch.h>
261
+ #include <ATen/ops/channel_shuffle_compositeexplicitautograd_dispatch.h>
262
+ #include <ATen/ops/cholesky_solve_compositeexplicitautograd_dispatch.h>
263
+ #include <ATen/ops/clone_compositeexplicitautograd_dispatch.h>
264
+ #include <ATen/ops/col_indices_compositeexplicitautograd_dispatch.h>
265
+ #include <ATen/ops/col_indices_copy_compositeexplicitautograd_dispatch.h>
266
+ #include <ATen/ops/complex_compositeexplicitautograd_dispatch.h>
267
+ #include <ATen/ops/conj_physical_compositeexplicitautograd_dispatch.h>
268
+ #include <ATen/ops/constant_pad_nd_compositeexplicitautograd_dispatch.h>
269
+ #include <ATen/ops/conv_depthwise3d_compositeexplicitautograd_dispatch.h>
270
+ #include <ATen/ops/conv_tbc_compositeexplicitautograd_dispatch.h>
271
+ #include <ATen/ops/convolution_compositeexplicitautograd_dispatch.h>
272
+ #include <ATen/ops/convolution_backward_compositeexplicitautograd_dispatch.h>
273
+ #include <ATen/ops/convolution_backward_overrideable_compositeexplicitautograd_dispatch.h>
274
+ #include <ATen/ops/convolution_overrideable_compositeexplicitautograd_dispatch.h>
275
+ #include <ATen/ops/copy_compositeexplicitautograd_dispatch.h>
276
+ #include <ATen/ops/copy_sparse_to_sparse_compositeexplicitautograd_dispatch.h>
277
+ #include <ATen/ops/copysign_compositeexplicitautograd_dispatch.h>
278
+ #include <ATen/ops/count_nonzero_compositeexplicitautograd_dispatch.h>
279
+ #include <ATen/ops/crow_indices_compositeexplicitautograd_dispatch.h>
280
+ #include <ATen/ops/crow_indices_copy_compositeexplicitautograd_dispatch.h>
281
+ #include <ATen/ops/cudnn_affine_grid_generator_compositeexplicitautograd_dispatch.h>
282
+ #include <ATen/ops/cudnn_affine_grid_generator_backward_compositeexplicitautograd_dispatch.h>
283
+ #include <ATen/ops/cudnn_batch_norm_compositeexplicitautograd_dispatch.h>
284
+ #include <ATen/ops/cudnn_batch_norm_backward_compositeexplicitautograd_dispatch.h>
285
+ #include <ATen/ops/cudnn_convolution_add_relu_compositeexplicitautograd_dispatch.h>
286
+ #include <ATen/ops/cudnn_convolution_relu_compositeexplicitautograd_dispatch.h>
287
+ #include <ATen/ops/cudnn_convolution_transpose_compositeexplicitautograd_dispatch.h>
288
+ #include <ATen/ops/cudnn_grid_sampler_compositeexplicitautograd_dispatch.h>
289
+ #include <ATen/ops/cudnn_grid_sampler_backward_compositeexplicitautograd_dispatch.h>
290
+ #include <ATen/ops/cummax_compositeexplicitautograd_dispatch.h>
291
+ #include <ATen/ops/cummin_compositeexplicitautograd_dispatch.h>
292
+ #include <ATen/ops/deg2rad_compositeexplicitautograd_dispatch.h>
293
+ #include <ATen/ops/dense_dim_compositeexplicitautograd_dispatch.h>
294
+ #include <ATen/ops/dequantize_compositeexplicitautograd_dispatch.h>
295
+ #include <ATen/ops/detach_compositeexplicitautograd_dispatch.h>
296
+ #include <ATen/ops/detach_copy_compositeexplicitautograd_dispatch.h>
297
+ #include <ATen/ops/diag_embed_compositeexplicitautograd_dispatch.h>
298
+ #include <ATen/ops/diagonal_compositeexplicitautograd_dispatch.h>
299
+ #include <ATen/ops/diagonal_backward_compositeexplicitautograd_dispatch.h>
300
+ #include <ATen/ops/diagonal_copy_compositeexplicitautograd_dispatch.h>
301
+ #include <ATen/ops/diagonal_scatter_compositeexplicitautograd_dispatch.h>
302
+ #include <ATen/ops/dist_compositeexplicitautograd_dispatch.h>
303
+ #include <ATen/ops/div_compositeexplicitautograd_dispatch.h>
304
+ #include <ATen/ops/dot_compositeexplicitautograd_dispatch.h>
305
+ #include <ATen/ops/embedding_compositeexplicitautograd_dispatch.h>
306
+ #include <ATen/ops/embedding_dense_backward_compositeexplicitautograd_dispatch.h>
307
+ #include <ATen/ops/embedding_renorm_compositeexplicitautograd_dispatch.h>
308
+ #include <ATen/ops/empty_compositeexplicitautograd_dispatch.h>
309
+ #include <ATen/ops/empty_like_compositeexplicitautograd_dispatch.h>
310
+ #include <ATen/ops/empty_permuted_compositeexplicitautograd_dispatch.h>
311
+ #include <ATen/ops/empty_quantized_compositeexplicitautograd_dispatch.h>
312
+ #include <ATen/ops/empty_strided_compositeexplicitautograd_dispatch.h>
313
+ #include <ATen/ops/expand_compositeexplicitautograd_dispatch.h>
314
+ #include <ATen/ops/expand_copy_compositeexplicitautograd_dispatch.h>
315
+ #include <ATen/ops/exponential_compositeexplicitautograd_dispatch.h>
316
+ #include <ATen/ops/eye_compositeexplicitautograd_dispatch.h>
317
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_compositeexplicitautograd_dispatch.h>
318
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_compositeexplicitautograd_dispatch.h>
319
+ #include <ATen/ops/fft_fftfreq_compositeexplicitautograd_dispatch.h>
320
+ #include <ATen/ops/fft_rfftfreq_compositeexplicitautograd_dispatch.h>
321
+ #include <ATen/ops/fill_compositeexplicitautograd_dispatch.h>
322
+ #include <ATen/ops/flip_compositeexplicitautograd_dispatch.h>
323
+ #include <ATen/ops/floor_divide_compositeexplicitautograd_dispatch.h>
324
+ #include <ATen/ops/fmod_compositeexplicitautograd_dispatch.h>
325
+ #include <ATen/ops/frexp_compositeexplicitautograd_dispatch.h>
326
+ #include <ATen/ops/from_file_compositeexplicitautograd_dispatch.h>
327
+ #include <ATen/ops/full_compositeexplicitautograd_dispatch.h>
328
+ #include <ATen/ops/full_like_compositeexplicitautograd_dispatch.h>
329
+ #include <ATen/ops/geometric_compositeexplicitautograd_dispatch.h>
330
+ #include <ATen/ops/glu_backward_jvp_compositeexplicitautograd_dispatch.h>
331
+ #include <ATen/ops/glu_jvp_compositeexplicitautograd_dispatch.h>
332
+ #include <ATen/ops/grid_sampler_2d_compositeexplicitautograd_dispatch.h>
333
+ #include <ATen/ops/grid_sampler_2d_backward_compositeexplicitautograd_dispatch.h>
334
+ #include <ATen/ops/grid_sampler_3d_compositeexplicitautograd_dispatch.h>
335
+ #include <ATen/ops/grid_sampler_3d_backward_compositeexplicitautograd_dispatch.h>
336
+ #include <ATen/ops/hamming_window_compositeexplicitautograd_dispatch.h>
337
+ #include <ATen/ops/hann_window_compositeexplicitautograd_dispatch.h>
338
+ #include <ATen/ops/hardswish_backward_compositeexplicitautograd_dispatch.h>
339
+ #include <ATen/ops/huber_loss_backward_compositeexplicitautograd_dispatch.h>
340
+ #include <ATen/ops/index_fill_compositeexplicitautograd_dispatch.h>
341
+ #include <ATen/ops/index_put_compositeexplicitautograd_dispatch.h>
342
+ #include <ATen/ops/indices_compositeexplicitautograd_dispatch.h>
343
+ #include <ATen/ops/indices_copy_compositeexplicitautograd_dispatch.h>
344
+ #include <ATen/ops/int_repr_compositeexplicitautograd_dispatch.h>
345
+ #include <ATen/ops/is_coalesced_compositeexplicitautograd_dispatch.h>
346
+ #include <ATen/ops/is_pinned_compositeexplicitautograd_dispatch.h>
347
+ #include <ATen/ops/is_same_size_compositeexplicitautograd_dispatch.h>
348
+ #include <ATen/ops/isinf_compositeexplicitautograd_dispatch.h>
349
+ #include <ATen/ops/isnan_compositeexplicitautograd_dispatch.h>
350
+ #include <ATen/ops/kaiser_window_compositeexplicitautograd_dispatch.h>
351
+ #include <ATen/ops/kthvalue_compositeexplicitautograd_dispatch.h>
352
+ #include <ATen/ops/lift_compositeexplicitautograd_dispatch.h>
353
+ #include <ATen/ops/lift_fresh_compositeexplicitautograd_dispatch.h>
354
+ #include <ATen/ops/lift_fresh_copy_compositeexplicitautograd_dispatch.h>
355
+ #include <ATen/ops/linalg_lstsq_compositeexplicitautograd_dispatch.h>
356
+ #include <ATen/ops/linalg_matrix_exp_compositeexplicitautograd_dispatch.h>
357
+ #include <ATen/ops/linalg_pinv_compositeexplicitautograd_dispatch.h>
358
+ #include <ATen/ops/linear_compositeexplicitautograd_dispatch.h>
359
+ #include <ATen/ops/linear_backward_compositeexplicitautograd_dispatch.h>
360
+ #include <ATen/ops/linspace_compositeexplicitautograd_dispatch.h>
361
+ #include <ATen/ops/log_normal_compositeexplicitautograd_dispatch.h>
362
+ #include <ATen/ops/log_softmax_compositeexplicitautograd_dispatch.h>
363
+ #include <ATen/ops/logcumsumexp_compositeexplicitautograd_dispatch.h>
364
+ #include <ATen/ops/logical_and_compositeexplicitautograd_dispatch.h>
365
+ #include <ATen/ops/logical_not_compositeexplicitautograd_dispatch.h>
366
+ #include <ATen/ops/logical_or_compositeexplicitautograd_dispatch.h>
367
+ #include <ATen/ops/logical_xor_compositeexplicitautograd_dispatch.h>
368
+ #include <ATen/ops/logspace_compositeexplicitautograd_dispatch.h>
369
+ #include <ATen/ops/logsumexp_compositeexplicitautograd_dispatch.h>
370
+ #include <ATen/ops/lshift_compositeexplicitautograd_dispatch.h>
371
+ #include <ATen/ops/lstm_mps_backward_compositeexplicitautograd_dispatch.h>
372
+ #include <ATen/ops/masked_fill_compositeexplicitautograd_dispatch.h>
373
+ #include <ATen/ops/masked_scatter_compositeexplicitautograd_dispatch.h>
374
+ #include <ATen/ops/masked_scatter_backward_compositeexplicitautograd_dispatch.h>
375
+ #include <ATen/ops/matmul_backward_compositeexplicitautograd_dispatch.h>
376
+ #include <ATen/ops/max_pool2d_backward_compositeexplicitautograd_dispatch.h>
377
+ #include <ATen/ops/mean_compositeexplicitautograd_dispatch.h>
378
+ #include <ATen/ops/median_compositeexplicitautograd_dispatch.h>
379
+ #include <ATen/ops/miopen_batch_norm_compositeexplicitautograd_dispatch.h>
380
+ #include <ATen/ops/miopen_batch_norm_backward_compositeexplicitautograd_dispatch.h>
381
+ #include <ATen/ops/miopen_convolution_compositeexplicitautograd_dispatch.h>
382
+ #include <ATen/ops/miopen_convolution_transpose_compositeexplicitautograd_dispatch.h>
383
+ #include <ATen/ops/miopen_depthwise_convolution_compositeexplicitautograd_dispatch.h>
384
+ #include <ATen/ops/miopen_rnn_compositeexplicitautograd_dispatch.h>
385
+ #include <ATen/ops/miopen_rnn_backward_compositeexplicitautograd_dispatch.h>
386
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_compositeexplicitautograd_dispatch.h>
387
+ #include <ATen/ops/mkldnn_convolution_compositeexplicitautograd_dispatch.h>
388
+ #include <ATen/ops/mkldnn_linear_compositeexplicitautograd_dispatch.h>
389
+ #include <ATen/ops/mkldnn_linear_backward_compositeexplicitautograd_dispatch.h>
390
+ #include <ATen/ops/mkldnn_linear_backward_input_compositeexplicitautograd_dispatch.h>
391
+ #include <ATen/ops/mkldnn_linear_backward_weights_compositeexplicitautograd_dispatch.h>
392
+ #include <ATen/ops/mkldnn_max_pool2d_compositeexplicitautograd_dispatch.h>
393
+ #include <ATen/ops/mkldnn_max_pool2d_backward_compositeexplicitautograd_dispatch.h>
394
+ #include <ATen/ops/mkldnn_max_pool3d_compositeexplicitautograd_dispatch.h>
395
+ #include <ATen/ops/mkldnn_max_pool3d_backward_compositeexplicitautograd_dispatch.h>
396
+ #include <ATen/ops/mkldnn_reorder_conv2d_weight_compositeexplicitautograd_dispatch.h>
397
+ #include <ATen/ops/mkldnn_reorder_conv3d_weight_compositeexplicitautograd_dispatch.h>
398
+ #include <ATen/ops/mkldnn_rnn_layer_compositeexplicitautograd_dispatch.h>
399
+ #include <ATen/ops/mkldnn_rnn_layer_backward_compositeexplicitautograd_dispatch.h>
400
+ #include <ATen/ops/mode_compositeexplicitautograd_dispatch.h>
401
+ #include <ATen/ops/mps_convolution_backward_compositeexplicitautograd_dispatch.h>
402
+ #include <ATen/ops/mps_convolution_transpose_backward_compositeexplicitautograd_dispatch.h>
403
+ #include <ATen/ops/mul_compositeexplicitautograd_dispatch.h>
404
+ #include <ATen/ops/mv_compositeexplicitautograd_dispatch.h>
405
+ #include <ATen/ops/mvlgamma_compositeexplicitautograd_dispatch.h>
406
+ #include <ATen/ops/nan_to_num_compositeexplicitautograd_dispatch.h>
407
+ #include <ATen/ops/nanmedian_compositeexplicitautograd_dispatch.h>
408
+ #include <ATen/ops/native_batch_norm_backward_compositeexplicitautograd_dispatch.h>
409
+ #include <ATen/ops/native_dropout_compositeexplicitautograd_dispatch.h>
410
+ #include <ATen/ops/native_dropout_backward_compositeexplicitautograd_dispatch.h>
411
+ #include <ATen/ops/native_group_norm_compositeexplicitautograd_dispatch.h>
412
+ #include <ATen/ops/native_group_norm_backward_compositeexplicitautograd_dispatch.h>
413
+ #include <ATen/ops/native_layer_norm_compositeexplicitautograd_dispatch.h>
414
+ #include <ATen/ops/native_layer_norm_backward_compositeexplicitautograd_dispatch.h>
415
+ #include <ATen/ops/native_norm_compositeexplicitautograd_dispatch.h>
416
+ #include <ATen/ops/new_empty_compositeexplicitautograd_dispatch.h>
417
+ #include <ATen/ops/new_empty_strided_compositeexplicitautograd_dispatch.h>
418
+ #include <ATen/ops/new_full_compositeexplicitautograd_dispatch.h>
419
+ #include <ATen/ops/new_ones_compositeexplicitautograd_dispatch.h>
420
+ #include <ATen/ops/new_zeros_compositeexplicitautograd_dispatch.h>
421
+ #include <ATen/ops/norm_compositeexplicitautograd_dispatch.h>
422
+ #include <ATen/ops/normal_compositeexplicitautograd_dispatch.h>
423
+ #include <ATen/ops/ones_compositeexplicitautograd_dispatch.h>
424
+ #include <ATen/ops/ones_like_compositeexplicitautograd_dispatch.h>
425
+ #include <ATen/ops/permute_compositeexplicitautograd_dispatch.h>
426
+ #include <ATen/ops/permute_copy_compositeexplicitautograd_dispatch.h>
427
+ #include <ATen/ops/pixel_shuffle_compositeexplicitautograd_dispatch.h>
428
+ #include <ATen/ops/pixel_unshuffle_compositeexplicitautograd_dispatch.h>
429
+ #include <ATen/ops/poisson_compositeexplicitautograd_dispatch.h>
430
+ #include <ATen/ops/polar_compositeexplicitautograd_dispatch.h>
431
+ #include <ATen/ops/polygamma_compositeexplicitautograd_dispatch.h>
432
+ #include <ATen/ops/prod_compositeexplicitautograd_dispatch.h>
433
+ #include <ATen/ops/put_compositeexplicitautograd_dispatch.h>
434
+ #include <ATen/ops/q_per_channel_scales_compositeexplicitautograd_dispatch.h>
435
+ #include <ATen/ops/q_per_channel_zero_points_compositeexplicitautograd_dispatch.h>
436
+ #include <ATen/ops/quantize_per_channel_compositeexplicitautograd_dispatch.h>
437
+ #include <ATen/ops/quantize_per_tensor_compositeexplicitautograd_dispatch.h>
438
+ #include <ATen/ops/quantize_per_tensor_dynamic_compositeexplicitautograd_dispatch.h>
439
+ #include <ATen/ops/quantized_batch_norm_compositeexplicitautograd_dispatch.h>
440
+ #include <ATen/ops/quantized_max_pool1d_compositeexplicitautograd_dispatch.h>
441
+ #include <ATen/ops/quantized_max_pool2d_compositeexplicitautograd_dispatch.h>
442
+ #include <ATen/ops/quantized_max_pool3d_compositeexplicitautograd_dispatch.h>
443
+ #include <ATen/ops/rad2deg_compositeexplicitautograd_dispatch.h>
444
+ #include <ATen/ops/rand_compositeexplicitautograd_dispatch.h>
445
+ #include <ATen/ops/rand_like_compositeexplicitautograd_dispatch.h>
446
+ #include <ATen/ops/randint_compositeexplicitautograd_dispatch.h>
447
+ #include <ATen/ops/randint_like_compositeexplicitautograd_dispatch.h>
448
+ #include <ATen/ops/randn_compositeexplicitautograd_dispatch.h>
449
+ #include <ATen/ops/randn_like_compositeexplicitautograd_dispatch.h>
450
+ #include <ATen/ops/random_compositeexplicitautograd_dispatch.h>
451
+ #include <ATen/ops/randperm_compositeexplicitautograd_dispatch.h>
452
+ #include <ATen/ops/range_compositeexplicitautograd_dispatch.h>
453
+ #include <ATen/ops/relu_compositeexplicitautograd_dispatch.h>
454
+ #include <ATen/ops/remainder_compositeexplicitautograd_dispatch.h>
455
+ #include <ATen/ops/repeat_compositeexplicitautograd_dispatch.h>
456
+ #include <ATen/ops/repeat_interleave_compositeexplicitautograd_dispatch.h>
457
+ #include <ATen/ops/resize_compositeexplicitautograd_dispatch.h>
458
+ #include <ATen/ops/resize_as_compositeexplicitautograd_dispatch.h>
459
+ #include <ATen/ops/resize_as_sparse_compositeexplicitautograd_dispatch.h>
460
+ #include <ATen/ops/roll_compositeexplicitautograd_dispatch.h>
461
+ #include <ATen/ops/rot90_compositeexplicitautograd_dispatch.h>
462
+ #include <ATen/ops/row_indices_compositeexplicitautograd_dispatch.h>
463
+ #include <ATen/ops/row_indices_copy_compositeexplicitautograd_dispatch.h>
464
+ #include <ATen/ops/rrelu_with_noise_backward_compositeexplicitautograd_dispatch.h>
465
+ #include <ATen/ops/rshift_compositeexplicitautograd_dispatch.h>
466
+ #include <ATen/ops/rsub_compositeexplicitautograd_dispatch.h>
467
+ #include <ATen/ops/scalar_tensor_compositeexplicitautograd_dispatch.h>
468
+ #include <ATen/ops/segment_reduce_compositeexplicitautograd_dispatch.h>
469
+ #include <ATen/ops/select_compositeexplicitautograd_dispatch.h>
470
+ #include <ATen/ops/select_backward_compositeexplicitautograd_dispatch.h>
471
+ #include <ATen/ops/select_copy_compositeexplicitautograd_dispatch.h>
472
+ #include <ATen/ops/select_scatter_compositeexplicitautograd_dispatch.h>
473
+ #include <ATen/ops/set_compositeexplicitautograd_dispatch.h>
474
+ #include <ATen/ops/slice_compositeexplicitautograd_dispatch.h>
475
+ #include <ATen/ops/slice_backward_compositeexplicitautograd_dispatch.h>
476
+ #include <ATen/ops/slice_copy_compositeexplicitautograd_dispatch.h>
477
+ #include <ATen/ops/slice_inverse_compositeexplicitautograd_dispatch.h>
478
+ #include <ATen/ops/slice_scatter_compositeexplicitautograd_dispatch.h>
479
+ #include <ATen/ops/slow_conv_dilated2d_compositeexplicitautograd_dispatch.h>
480
+ #include <ATen/ops/slow_conv_dilated3d_compositeexplicitautograd_dispatch.h>
481
+ #include <ATen/ops/smooth_l1_loss_backward_compositeexplicitautograd_dispatch.h>
482
+ #include <ATen/ops/soft_margin_loss_compositeexplicitautograd_dispatch.h>
483
+ #include <ATen/ops/soft_margin_loss_backward_compositeexplicitautograd_dispatch.h>
484
+ #include <ATen/ops/softmax_compositeexplicitautograd_dispatch.h>
485
+ #include <ATen/ops/sort_compositeexplicitautograd_dispatch.h>
486
+ #include <ATen/ops/sparse_compressed_tensor_compositeexplicitautograd_dispatch.h>
487
+ #include <ATen/ops/sparse_coo_tensor_compositeexplicitautograd_dispatch.h>
488
+ #include <ATen/ops/sparse_dim_compositeexplicitautograd_dispatch.h>
489
+ #include <ATen/ops/sparse_mask_compositeexplicitautograd_dispatch.h>
490
+ #include <ATen/ops/sparse_resize_compositeexplicitautograd_dispatch.h>
491
+ #include <ATen/ops/sparse_resize_and_clear_compositeexplicitautograd_dispatch.h>
492
+ #include <ATen/ops/special_chebyshev_polynomial_t_compositeexplicitautograd_dispatch.h>
493
+ #include <ATen/ops/special_chebyshev_polynomial_u_compositeexplicitautograd_dispatch.h>
494
+ #include <ATen/ops/special_chebyshev_polynomial_v_compositeexplicitautograd_dispatch.h>
495
+ #include <ATen/ops/special_chebyshev_polynomial_w_compositeexplicitautograd_dispatch.h>
496
+ #include <ATen/ops/special_hermite_polynomial_h_compositeexplicitautograd_dispatch.h>
497
+ #include <ATen/ops/special_hermite_polynomial_he_compositeexplicitautograd_dispatch.h>
498
+ #include <ATen/ops/special_laguerre_polynomial_l_compositeexplicitautograd_dispatch.h>
499
+ #include <ATen/ops/special_legendre_polynomial_p_compositeexplicitautograd_dispatch.h>
500
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_t_compositeexplicitautograd_dispatch.h>
501
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_u_compositeexplicitautograd_dispatch.h>
502
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_v_compositeexplicitautograd_dispatch.h>
503
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_w_compositeexplicitautograd_dispatch.h>
504
+ #include <ATen/ops/special_xlog1py_compositeexplicitautograd_dispatch.h>
505
+ #include <ATen/ops/special_zeta_compositeexplicitautograd_dispatch.h>
506
+ #include <ATen/ops/split_compositeexplicitautograd_dispatch.h>
507
+ #include <ATen/ops/split_copy_compositeexplicitautograd_dispatch.h>
508
+ #include <ATen/ops/split_with_sizes_compositeexplicitautograd_dispatch.h>
509
+ #include <ATen/ops/split_with_sizes_copy_compositeexplicitautograd_dispatch.h>
510
+ #include <ATen/ops/squeeze_compositeexplicitautograd_dispatch.h>
511
+ #include <ATen/ops/squeeze_copy_compositeexplicitautograd_dispatch.h>
512
+ #include <ATen/ops/stack_compositeexplicitautograd_dispatch.h>
513
+ #include <ATen/ops/std_mean_compositeexplicitautograd_dispatch.h>
514
+ #include <ATen/ops/sub_compositeexplicitautograd_dispatch.h>
515
+ #include <ATen/ops/sum_compositeexplicitautograd_dispatch.h>
516
+ #include <ATen/ops/sym_constrain_range_compositeexplicitautograd_dispatch.h>
517
+ #include <ATen/ops/sym_constrain_range_for_size_compositeexplicitautograd_dispatch.h>
518
+ #include <ATen/ops/t_compositeexplicitautograd_dispatch.h>
519
+ #include <ATen/ops/t_copy_compositeexplicitautograd_dispatch.h>
520
+ #include <ATen/ops/to_mkldnn_compositeexplicitautograd_dispatch.h>
521
+ #include <ATen/ops/to_padded_tensor_compositeexplicitautograd_dispatch.h>
522
+ #include <ATen/ops/trace_compositeexplicitautograd_dispatch.h>
523
+ #include <ATen/ops/transpose_compositeexplicitautograd_dispatch.h>
524
+ #include <ATen/ops/transpose_copy_compositeexplicitautograd_dispatch.h>
525
+ #include <ATen/ops/tril_indices_compositeexplicitautograd_dispatch.h>
526
+ #include <ATen/ops/triu_indices_compositeexplicitautograd_dispatch.h>
527
+ #include <ATen/ops/unbind_compositeexplicitautograd_dispatch.h>
528
+ #include <ATen/ops/unbind_copy_compositeexplicitautograd_dispatch.h>
529
+ #include <ATen/ops/unfold_backward_compositeexplicitautograd_dispatch.h>
530
+ #include <ATen/ops/unfold_copy_compositeexplicitautograd_dispatch.h>
531
+ #include <ATen/ops/uniform_compositeexplicitautograd_dispatch.h>
532
+ #include <ATen/ops/unique_consecutive_compositeexplicitautograd_dispatch.h>
533
+ #include <ATen/ops/unique_dim_compositeexplicitautograd_dispatch.h>
534
+ #include <ATen/ops/unique_dim_consecutive_compositeexplicitautograd_dispatch.h>
535
+ #include <ATen/ops/unsafe_split_compositeexplicitautograd_dispatch.h>
536
+ #include <ATen/ops/unsafe_split_with_sizes_compositeexplicitautograd_dispatch.h>
537
+ #include <ATen/ops/unsqueeze_compositeexplicitautograd_dispatch.h>
538
+ #include <ATen/ops/unsqueeze_copy_compositeexplicitautograd_dispatch.h>
539
+ #include <ATen/ops/values_compositeexplicitautograd_dispatch.h>
540
+ #include <ATen/ops/values_copy_compositeexplicitautograd_dispatch.h>
541
+ #include <ATen/ops/var_mean_compositeexplicitautograd_dispatch.h>
542
+ #include <ATen/ops/vdot_compositeexplicitautograd_dispatch.h>
543
+ #include <ATen/ops/view_compositeexplicitautograd_dispatch.h>
544
+ #include <ATen/ops/view_as_complex_copy_compositeexplicitautograd_dispatch.h>
545
+ #include <ATen/ops/view_as_real_copy_compositeexplicitautograd_dispatch.h>
546
+ #include <ATen/ops/view_copy_compositeexplicitautograd_dispatch.h>
547
+ #include <ATen/ops/xlogy_compositeexplicitautograd_dispatch.h>
548
+ #include <ATen/ops/zero_compositeexplicitautograd_dispatch.h>
549
+ #include <ATen/ops/zeros_compositeexplicitautograd_dispatch.h>
550
+ #include <ATen/ops/zeros_like_compositeexplicitautograd_dispatch.h>
551
+
552
+
553
+
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/TensorBody.h>
2
+
3
+ // TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
4
+ // Code introduced to avoid cyclic dependency in static dispatch is no longer
5
+ // needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
6
+ // to Operators.cpp for supporting multiple backends with multiple kernels.
7
+ //
8
+ // Note [Avoiding Include Cycles In Static Dispatch]
9
+ // In order to avoid #include cycles in the static dispatch build, we've carefully split out
10
+ // the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
11
+ //
12
+ // Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
13
+ // - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
14
+ // all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
15
+ // directly inlined into TensorBody.h.
16
+ // - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
17
+ // which include functions that have defaultable std::optional<Tensor> arguments.
18
+ // That requires knowing the full Tensor class definition.
19
+ //
20
+ // We break the cycle by doing the following:
21
+ // - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
22
+ // - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
23
+ // - CPUFunctions_inl.h includes everything else
24
+ // - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
25
+ // and then it includes CPUFunctions_inl.h.
26
+ // - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
27
+ // - This also means that static dispatch build, CPUFunctions.h only needs to
28
+ // #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
29
+ #include <ATen/CompositeImplicitAutogradFunctions_inl.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/TensorBody.h>
2
+
3
+ // TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
4
+ // Code introduced to avoid cyclic dependency in static dispatch is no longer
5
+ // needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
6
+ // to Operators.cpp for supporting multiple backends with multiple kernels.
7
+ //
8
+ // Note [Avoiding Include Cycles In Static Dispatch]
9
+ // In order to avoid #include cycles in the static dispatch build, we've carefully split out
10
+ // the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
11
+ //
12
+ // Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
13
+ // - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
14
+ // all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
15
+ // directly inlined into TensorBody.h.
16
+ // - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
17
+ // which include functions that have defaultable std::optional<Tensor> arguments.
18
+ // That requires knowing the full Tensor class definition.
19
+ //
20
+ // We break the cycle by doing the following:
21
+ // - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
22
+ // - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
23
+ // - CPUFunctions_inl.h includes everything else
24
+ // - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
25
+ // and then it includes CPUFunctions_inl.h.
26
+ // - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
27
+ // - This also means that static dispatch build, CPUFunctions.h only needs to
28
+ // #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
29
+ #include <ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
12
+ #error This change adds a dependency on all pytorch operators, meaning the \
13
+ file will need to be re-compiled every time an operator is changed or added. \
14
+ Consider including a specific operator from \
15
+ <ATen/ops/{my_operator}_compositeimplicitautogradnestedtensor_dispatch.h>. \
16
+ See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
17
+ #endif
18
+
19
+ #include <ATen/ops/randn_like_compositeimplicitautogradnestedtensor_dispatch.h>
20
+ #include <ATen/ops/reshape_compositeimplicitautogradnestedtensor_dispatch.h>
21
+ #include <ATen/ops/reshape_as_compositeimplicitautogradnestedtensor_dispatch.h>
22
+ #include <ATen/ops/zeros_like_compositeimplicitautogradnestedtensor_dispatch.h>
23
+
24
+
25
+
.venv/lib/python3.11/site-packages/torch/include/ATen/Config.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // Test these using #if AT_MKL_ENABLED(), not #ifdef, so that it's
4
+ // obvious if you forgot to include Config.h
5
+ // c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
6
+ //
7
+ // DO NOT put the macros for CUDA libraries in this file; they belong in cuda/CUDAConfig.h
8
+
9
+ #define AT_MKLDNN_ENABLED() 1
10
+ #define AT_MKLDNN_ACL_ENABLED() 0
11
+ #define AT_MKL_ENABLED() 1
12
+ #define AT_MKL_SEQUENTIAL() 0
13
+ #define AT_POCKETFFT_ENABLED() 0
14
+ #define AT_NNPACK_ENABLED() 1
15
+ #define CAFFE2_STATIC_LINK_CUDA() 0
16
+ #define AT_BUILD_WITH_BLAS() 1
17
+ #define AT_BUILD_WITH_LAPACK() 1
18
+ #define AT_PARALLEL_OPENMP 1
19
+ #define AT_PARALLEL_NATIVE 0
20
+ #define AT_BLAS_F2C() 0
21
+ #define AT_BLAS_USE_CBLAS_DOT() 0
.venv/lib/python3.11/site-packages/torch/include/ATen/Context.h ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/BlasBackend.h>
4
+ #include <ATen/CPUGeneratorImpl.h>
5
+ #include <ATen/DeviceAccelerator.h>
6
+ #include <ATen/LinalgBackend.h>
7
+ #include <ATen/core/ATenGeneral.h>
8
+ #include <ATen/core/DeprecatedTypeProperties.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/LegacyTypeDispatch.h>
11
+ #include <ATen/detail/AcceleratorHooksInterface.h>
12
+ #include <ATen/detail/CUDAHooksInterface.h>
13
+ #include <ATen/detail/HIPHooksInterface.h>
14
+ #include <ATen/detail/IPUHooksInterface.h>
15
+ #include <ATen/detail/MAIAHooksInterface.h>
16
+ #include <ATen/detail/MPSHooksInterface.h>
17
+ #include <ATen/detail/MTIAHooksInterface.h>
18
+ #include <ATen/detail/PrivateUse1HooksInterface.h>
19
+ #include <ATen/detail/XPUHooksInterface.h>
20
+ #include <c10/core/QEngine.h>
21
+ #include <c10/core/impl/DeviceGuardImplInterface.h>
22
+ #include <c10/util/CallOnce.h>
23
+ #include <c10/util/Exception.h>
24
+ #include <c10/util/env.h>
25
+ #include <c10/util/irange.h>
26
+
27
+ #include <cstdint>
28
+ #include <mutex>
29
+
30
+ namespace at {
31
+
32
+ class Tensor;
33
+
34
+ enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
35
+
36
+ class TORCH_API Context {
37
+ public:
38
+ Context();
39
+
40
+ const Generator& defaultGenerator(Device device) {
41
+ c10::DeviceType device_type = device.type();
42
+ initCUDAIfNeeded(device_type);
43
+ initHIPIfNeeded(device_type);
44
+ if (device_type == at::kCPU) {
45
+ return at::detail::getDefaultCPUGenerator();
46
+ } else if (device_type == at::kCUDA) {
47
+ return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
48
+ } else if (device_type == at::kMPS) {
49
+ return at::detail::getMPSHooks().getDefaultMPSGenerator();
50
+ } else if (device_type == at::kXPU) {
51
+ return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
52
+ } else if (device_type == at::kIPU) {
53
+ return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
54
+ } else if (device_type == at::kPrivateUse1) {
55
+ return at::detail::getPrivateUse1Hooks().getDefaultGenerator(
56
+ device.index());
57
+ } else {
58
+ AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
59
+ }
60
+ }
61
+ const AcceleratorHooksInterface& getAcceleratorHooksInterface(
62
+ std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
63
+ c10::DeviceType device_type = opt_device_type.has_value()
64
+ ? opt_device_type.value()
65
+ : at::getAccelerator(true).value();
66
+ if (device_type == at::kCUDA) {
67
+ return at::detail::getCUDAHooks();
68
+ } else if (device_type == at::kXPU) {
69
+ return at::detail::getXPUHooks();
70
+ } else if (device_type == at::kMPS) {
71
+ return at::detail::getMPSHooks();
72
+ } else if (device_type == at::kPrivateUse1) {
73
+ return at::detail::getPrivateUse1Hooks();
74
+ } else if (device_type == at::kMTIA) {
75
+ return at::detail::getMTIAHooks();
76
+ } else if (device_type == at::kHIP) {
77
+ return at::detail::getHIPHooks();
78
+ } else {
79
+ AT_ERROR(
80
+ c10::DeviceTypeName(device_type), " device type not an accelerator.");
81
+ }
82
+ }
83
+ Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
84
+ initCUDAIfNeeded(device_type);
85
+ initHIPIfNeeded(device_type);
86
+ initXPUIfNeeded(device_type);
87
+ if (device_type == at::kCPU) {
88
+ return c10::DeviceType::CPU;
89
+ } else if (device_type == at::kCUDA) {
90
+ return at::detail::getCUDAHooks().getDeviceFromPtr(data);
91
+ } else if (device_type == at::kXPU) {
92
+ return at::detail::getXPUHooks().getDeviceFromPtr(data);
93
+ } else if (device_type == at::kPrivateUse1) {
94
+ return at::detail::getPrivateUse1Hooks().getDeviceFromPtr(data);
95
+ } else {
96
+ AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
97
+ }
98
+ }
99
+ bool isPinnedPtr(
100
+ const void* data,
101
+ std::optional<c10::DeviceType> device_type = std::nullopt) {
102
+ auto opt_device_type =
103
+ device_type.has_value() ? device_type : at::getAccelerator();
104
+ if (!opt_device_type.has_value() || // there is no accelerator
105
+ !at::isAccelerator(
106
+ opt_device_type.value())) { // passed device not an accelerator
107
+ return false;
108
+ }
109
+ return getAcceleratorHooksInterface(opt_device_type.value())
110
+ .isPinnedPtr(data);
111
+ }
112
+ Allocator* getPinnedMemoryAllocator(
113
+ std::optional<c10::DeviceType> device_type = std::nullopt) {
114
+ return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator();
115
+ }
116
+ static bool hasOpenMP();
117
+ static bool hasMKL();
118
+ static bool hasLAPACK();
119
+ static bool hasMKLDNN();
120
+ static bool hasMAGMA() {
121
+ return detail::getCUDAHooks().hasMAGMA();
122
+ }
123
+ static bool hasCUDA() {
124
+ return detail::getCUDAHooks().hasCUDA();
125
+ }
126
+ static bool hasMTIA() {
127
+ return detail::getMTIAHooks().hasMTIA();
128
+ }
129
+ static bool hasCUDART() {
130
+ return detail::getCUDAHooks().hasCUDART();
131
+ }
132
+ static long versionCUDART() {
133
+ return detail::getCUDAHooks().versionCUDART();
134
+ }
135
+ static bool hasCuDNN() {
136
+ return detail::getCUDAHooks().hasCuDNN();
137
+ }
138
+ static long versionCuDNN() {
139
+ return detail::getCUDAHooks().versionCuDNN();
140
+ }
141
+ static bool hasCuSOLVER() {
142
+ return detail::getCUDAHooks().hasCuSOLVER();
143
+ }
144
+ static bool hasCuBLASLt() {
145
+ return detail::getCUDAHooks().hasCuBLASLt();
146
+ }
147
+ static bool hasHIP() {
148
+ return detail::getHIPHooks().hasHIP();
149
+ }
150
+ static bool hasMPS() {
151
+ return detail::getMPSHooks().hasMPS();
152
+ }
153
+ static bool hasIPU() {
154
+ return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
155
+ }
156
+ static bool hasXLA() {
157
+ return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
158
+ }
159
+ static bool hasXPU() {
160
+ return detail::getXPUHooks().hasXPU();
161
+ }
162
+ static bool hasLazy() {
163
+ return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
164
+ }
165
+ static bool hasMAIA() {
166
+ return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
167
+ }
168
+ // defined in header so that getNonVariableType has ability to inline
169
+ // call_once check. getNonVariableType is called fairly frequently
170
+ void lazyInitCUDA() {
171
+ c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
172
+ }
173
+ void lazyInitHIP() {
174
+ c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
175
+ }
176
+ void lazyInitXPU() {
177
+ c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
178
+ }
179
+ void lazyInitMTIA() {
180
+ c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
181
+ }
182
+ void lazyInitPrivateUse1() {
183
+ c10::call_once(thp_init, [&] {
184
+ if (isPrivateUse1HooksRegistered()) {
185
+ at::detail::getPrivateUse1Hooks().initPrivateUse1();
186
+ }
187
+ });
188
+ }
189
+ static const at::cuda::NVRTC& getNVRTC() {
190
+ return detail::getCUDAHooks().nvrtc();
191
+ }
192
+
193
+ static bool setFlushDenormal(bool on);
194
+
195
+ // NB: This method is *purely* whether or not a user requested
196
+ // that CuDNN was enabled, it doesn't actually say anything about
197
+ // whether or not CuDNN is actually usable. Use cudnn_is_acceptable
198
+ // to test this instead
199
+ bool userEnabledCuDNN() const;
200
+ void setUserEnabledCuDNN(bool e);
201
+ bool userEnabledMkldnn() const;
202
+ void setUserEnabledMkldnn(bool e);
203
+ bool benchmarkCuDNN() const;
204
+ void setBenchmarkCuDNN(bool);
205
+ int benchmarkLimitCuDNN() const;
206
+ void setBenchmarkLimitCuDNN(int);
207
+ bool deterministicCuDNN() const;
208
+ void setDeterministicCuDNN(bool);
209
+ bool deterministicMkldnn() const;
210
+ void setDeterministicMkldnn(bool);
211
+ bool userEnabledNNPACK() const;
212
+ void setUserEnabledNNPACK(bool e);
213
+
214
+ // Note [Disabling Fused SDP Kernels]
215
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
216
+ // Flash and Memory Efficient SDP kernels are enabled by default.
217
+ // However, they can be disabled by setting
218
+ // at::globalContext().setUserEnabledFlashSDP(false) flag.
219
+ // This is useful for debugging purposes. For example, if you want to
220
+ // compare the performance of the flash SDP kernels with the unfused
221
+ // kernel, you can disable the flash SDP kernels. By disabling
222
+ // the math SDP kernel, you can force your code to use flash kernels.
223
+ // The math SDP kernel can be disabled by setting
224
+ // at::globalContext().setUserEnabledMathSDP(false) flag.
225
+ void setSDPUseFlash(bool);
226
+ bool userEnabledFlashSDP() const;
227
+
228
+ void setSDPUseMemEfficient(bool);
229
+ bool userEnabledMemEfficientSDP() const;
230
+
231
+ void setSDPUseMath(bool);
232
+ bool userEnabledMathSDP() const;
233
+
234
+ void setSDPUseCuDNN(bool);
235
+ bool userEnabledCuDNNSDP() const;
236
+
237
+ void setAllowFP16BF16ReductionMathSDP(bool);
238
+ bool allowFP16BF16ReductionMathSDP() const;
239
+
240
+ void setSDPUseOverrideable(bool);
241
+ bool userEnabledOverrideableSDP() const;
242
+
243
+ at::LinalgBackend linalgPreferredBackend() const;
244
+ void setLinalgPreferredBackend(at::LinalgBackend);
245
+
246
+ at::BlasBackend blasPreferredBackend();
247
+ void setBlasPreferredBackend(at::BlasBackend);
248
+
249
+ // Note [Enabling Deterministic Operations]
250
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
251
+ // Operations in PyTorch that normally act nondeterministically, but have an
252
+ // alternate deterministic implementation, should satisfy the following
253
+ // requirements:
254
+ //
255
+ // * Include this comment: "See Note [Enabling Deterministic Operations]"
256
+ //
257
+ // * Check the value of `at::globalContext().deterministicAlgorithms()` to
258
+ // toggle
259
+ // between nondeterministic and deterministic implementations.
260
+ //
261
+ // * Have an entry in the list of PyTorch operations that toggle between
262
+ // nondeterministic
263
+ // and deterministic implementations, in the docstring of
264
+ // `use_deterministic_algorithms()` in torch/__init__.py
265
+ //
266
+ // `example_func()` below shows an example of toggling between
267
+ // nondeterministic and deterministic implementations:
268
+ //
269
+ // void example_func() {
270
+ // // See Note [Enabling Deterministic Operations]
271
+ // if (at::globalContext().deterministicAlgorithms()) {
272
+ // example_func_deterministic();
273
+ // } else {
274
+ // example_func_nondeterministic();
275
+ // }
276
+ // }
277
+
278
+ bool deterministicAlgorithms() const;
279
+ bool deterministicAlgorithmsWarnOnly() const;
280
+ void setDeterministicAlgorithms(bool, bool);
281
+ bool deterministicFillUninitializedMemory() const;
282
+ void setDeterministicFillUninitializedMemory(bool);
283
+
284
+ // Note [Writing Nondeterministic Operations]
285
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
286
+ // Operations in PyTorch that act nondeterministically and do not have an
287
+ // alternate deterministic implementation should satisfy the following
288
+ // requirements:
289
+ //
290
+ // * Include this comment: "See Note [Writing Nondeterministic Operations]"
291
+ //
292
+ // * Include a comment explaining why the operation is nondeterministic.
293
+ //
294
+ // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
295
+ // of the time, this should be accomplished by calling
296
+ // `at::globalContext().alertNotDeterminstic()`. However, if the
297
+ // nondeterministic behavior is caused by the CuBLAS workspace
298
+ // configuration in CUDA >= 10.2,
299
+ // `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
300
+ // called instead (in this case, a comment explaining why the operation is
301
+ // nondeterministic is not necessary). See below for details on these
302
+ // methods.
303
+ //
304
+ // * Have an entry in the list of nondeterministic PyTorch operations in the
305
+ // docstring of `use_deterministic_algorithms()` in torch/__init__.py
306
+ //
307
+ // * Have a test function in `test/test_torch.py` whose name begins with
308
+ // `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
309
+ // configuration is the reason for nondeterminism, the operation should be
310
+ // included in the `test_cublas_config_nondeterministic_alert` test. Any new
311
+ // tests should ideally follow a pattern similar to the existing ones.
312
+ //
313
+ // `example_func()` below shows an example of the comments and error-throwing
314
+ // code for a nondeterministic operation:
315
+ //
316
+ // void example_func() {
317
+ // // See Note [Writing Nondeterministic Operations]
318
+ // // Nondeterministic because <reason>
319
+ // at::globalContext().alertNotDeterministic("example_func");
320
+ // ...
321
+ // }
322
+
323
+ // Throws an error if `Context::deterministicAlgorithms()` is true
324
+ static void alertNotDeterministic(c10::string_view const& caller);
325
+
326
+ // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
327
+ // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
328
+ // ":4096:8". For more details:
329
+ // https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
330
+ void alertCuBLASConfigNotDeterministic() const;
331
+
332
+ void setFloat32MatmulPrecision(const std::string& s);
333
+ bool allowTF32CuDNN() const;
334
+ void setAllowTF32CuDNN(bool);
335
+ bool allowTF32CuBLAS() const;
336
+ void setAllowTF32CuBLAS(bool);
337
+ Float32MatmulPrecision float32MatmulPrecision() const;
338
+ void setFloat32MatmulPrecision(Float32MatmulPrecision p);
339
+ bool allowFP16ReductionCuBLAS() const;
340
+ void setAllowFP16ReductionCuBLAS(bool);
341
+ bool allowBF16ReductionCuBLAS() const;
342
+ void setAllowBF16ReductionCuBLAS(bool);
343
+ at::QEngine qEngine() const;
344
+ void setQEngine(at::QEngine e);
345
+ static const std::vector<at::QEngine>& supportedQEngines();
346
+ static bool isXNNPACKAvailable();
347
+ void setCheckSparseTensorInvariants(bool e);
348
+ bool checkSparseTensorInvariants() const;
349
+ // This method is used to release the original weight after pre-packing.
350
+ // It should be called once before loading/running the model.
351
+ // NB: By default it is set to true for mobile builds.
352
+ void setReleaseWeightsWhenPrepacking(bool e);
353
+ bool releaseWeightsWhenPrepacking() const;
354
+
355
+ void setDisplayVmapFallbackWarnings(bool enabled);
356
+ bool areVmapFallbackWarningsEnabled() const;
357
+
358
+ void setDefaultMobileCPUAllocator();
359
+ void unsetDefaultMobileCPUAllocator();
360
+ bool allowFP16ReductionCPU() const;
361
+ void setAllowFP16ReductionCPU(bool);
362
+
363
+ private:
364
+ void initCUDAIfNeeded(c10::DeviceType p) {
365
+ if (p == c10::DeviceType::CUDA) {
366
+ lazyInitCUDA();
367
+ }
368
+ }
369
+ void initHIPIfNeeded(c10::DeviceType p) {
370
+ if (p == c10::DeviceType::HIP) {
371
+ lazyInitHIP();
372
+ }
373
+ }
374
+ void initXPUIfNeeded(c10::DeviceType p) {
375
+ if (p == c10::DeviceType::XPU) {
376
+ lazyInitXPU();
377
+ }
378
+ }
379
+ static bool checkCuBLASConfigDeterministic();
380
+ c10::once_flag thc_init;
381
+ c10::once_flag thh_init;
382
+ c10::once_flag thx_init;
383
+ c10::once_flag th_mtia_init;
384
+ c10::once_flag thp_init;
385
+ bool enabled_cudnn = true;
386
+ bool deterministic_cudnn = false;
387
+ bool deterministic_mkldnn = false;
388
+ bool _deterministic_algorithms = false;
389
+ bool _deterministic_algorithms_warn_only = false;
390
+ bool _deterministic_fill_uninitialized_memory = true;
391
+ bool enabled_flashSDP = true;
392
+ bool enabled_mem_efficientSDP = true;
393
+ bool enabled_mathSDP = true;
394
+ bool enabled_cudnnSDP = true;
395
+ bool enabled_overrideable = true;
396
+ bool allow_fp16_bf16_reduction_mathSDP = false;
397
+ #ifdef USE_ROCM
398
+ bool benchmark_cudnn = true;
399
+ #else
400
+ bool benchmark_cudnn = false;
401
+ #endif
402
+ Float32MatmulPrecision float32_matmul_precision =
403
+ c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
404
+ ? at::Float32MatmulPrecision::HIGH
405
+ : at::Float32MatmulPrecision::HIGHEST;
406
+ int benchmark_limit_cudnn = 10;
407
+ bool allow_tf32_cudnn = true;
408
+ bool allow_fp16_reduction_cublas = true;
409
+ bool allow_bf16_reduction_cublas = true;
410
+ bool enabled_mkldnn = true;
411
+ bool enabled_nnpack = true;
412
+ at::LinalgBackend linalg_preferred_backend =
413
+ c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
414
+ ? at::LinalgBackend::Cusolver
415
+ : at::LinalgBackend::Default;
416
+ at::BlasBackend blas_preferred_backend =
417
+ #ifdef USE_ROCM
418
+ (c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false)
419
+ #else
420
+ (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true)
421
+ #endif
422
+ ? at::BlasBackend::Cublaslt
423
+ : at::BlasBackend::Cublas;
424
+ #ifdef C10_MOBILE
425
+ bool release_original_weights = true;
426
+ #else
427
+ bool release_original_weights = false;
428
+ #endif
429
+ bool display_vmap_fallback_warnings_ = false;
430
+ std::optional<at::QEngine> quantized_engine = std::nullopt;
431
+ bool enable_sparse_tensor_invariant_checks = false;
432
+ bool allow_fp16_reduction_cpu = false;
433
+
434
+ Allocator* prev_allocator_ptr_{nullptr};
435
+ };
436
+
437
+ TORCH_API Context& globalContext();
438
+
439
+ inline void init() {
440
+ globalContext();
441
+ }
442
+
443
+ TORCH_API Allocator* getCPUAllocator();
444
+
445
+ inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
446
+ Backend p,
447
+ ScalarType s) {
448
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
449
+ p, s);
450
+ }
451
+
452
+ inline DeprecatedTypeProperties& CPU(ScalarType s) {
453
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
454
+ Backend::CPU, s);
455
+ }
456
+
457
+ inline DeprecatedTypeProperties& CUDA(ScalarType s) {
458
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
459
+ Backend::CUDA, s);
460
+ }
461
+
462
+ inline DeprecatedTypeProperties& HIP(ScalarType s) {
463
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
464
+ Backend::HIP, s);
465
+ }
466
+
467
+ inline DeprecatedTypeProperties& MPS(ScalarType s) {
468
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
469
+ Backend::MPS, s);
470
+ }
471
+
472
+ inline bool hasCUDA() {
473
+ return globalContext().hasCUDA();
474
+ }
475
+
476
+ inline bool hasMTIA() {
477
+ return globalContext().hasMTIA();
478
+ }
479
+
480
+ inline bool hasHIP() {
481
+ return globalContext().hasHIP();
482
+ }
483
+
484
+ inline bool hasIPU() {
485
+ return globalContext().hasIPU();
486
+ }
487
+
488
+ inline bool hasXLA() {
489
+ return globalContext().hasXLA();
490
+ }
491
+
492
+ inline bool hasMPS() {
493
+ return globalContext().hasMPS();
494
+ }
495
+
496
+ inline bool hasMAIA() {
497
+ return globalContext().hasMAIA();
498
+ }
499
+
500
+ inline bool hasXPU() {
501
+ return globalContext().hasXPU();
502
+ }
503
+
504
+ // Despite its name, this function returns the number of *CUDA* GPUs.
505
+ inline size_t getNumGPUs() {
506
+ // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
507
+ // FUNCTION. If you are interested in interrogating the number of
508
+ // devices for a specific device type, add that function to the
509
+ // relevant library (e.g., similar to at::cuda::device_count())
510
+ if (hasCUDA() && hasHIP()) {
511
+ throw std::runtime_error(
512
+ "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
513
+ "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
514
+ "means HIP. Rebuild PyTorch with one or the other disabled.");
515
+ } else if (hasCUDA()) {
516
+ return detail::getCUDAHooks().getNumGPUs();
517
+ } else if (hasHIP()) {
518
+ return detail::getHIPHooks().getNumGPUs();
519
+ } else {
520
+ return 0;
521
+ }
522
+ }
523
+
524
+ inline bool hasOpenMP() {
525
+ return globalContext().hasOpenMP();
526
+ }
527
+
528
+ inline bool hasMKL() {
529
+ return globalContext().hasMKL();
530
+ }
531
+
532
+ inline bool hasLAPACK() {
533
+ return globalContext().hasLAPACK();
534
+ }
535
+
536
+ inline bool hasMAGMA() {
537
+ return globalContext().hasMAGMA();
538
+ }
539
+
540
+ inline bool hasMKLDNN() {
541
+ return globalContext().hasMKLDNN();
542
+ }
543
+
544
+ inline void manual_seed(uint64_t seed) {
545
+ auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
546
+ {
547
+ // See Note [Acquire lock when using random generators]
548
+ std::lock_guard<std::mutex> lock(gen.mutex());
549
+ gen.set_current_seed(seed);
550
+ }
551
+ // NB: Sometimes we build with CUDA, but we don't have any GPUs
552
+ // available. In that case, we must not seed CUDA; it will fail!
553
+ const auto cuda_num_gpus = detail::getCUDAHooks().getNumGPUs();
554
+ if (hasCUDA() && cuda_num_gpus > 0) {
555
+ for (const auto i : c10::irange(cuda_num_gpus)) {
556
+ auto cuda_gen = globalContext().defaultGenerator(
557
+ Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
558
+ {
559
+ // See Note [Acquire lock when using random generators]
560
+ std::lock_guard<std::mutex> lock(cuda_gen.mutex());
561
+ cuda_gen.set_current_seed(seed);
562
+ }
563
+ }
564
+ }
565
+
566
+ const auto xpu_num_gpus = detail::getXPUHooks().getNumGPUs();
567
+ if (hasXPU() && xpu_num_gpus) {
568
+ for (const auto i : c10::irange(xpu_num_gpus)) {
569
+ auto xpu_gen = globalContext().defaultGenerator(
570
+ Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
571
+ {
572
+ // See Note [Acquire lock when using random generators]
573
+ std::lock_guard<std::mutex> lock(xpu_gen.mutex());
574
+ xpu_gen.set_current_seed(seed);
575
+ }
576
+ }
577
+ }
578
+
579
+ if (hasMPS()) {
580
+ auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS);
581
+ // See Note [Acquire lock when using random generators]
582
+ std::lock_guard<std::mutex> lock(mps_gen.mutex());
583
+ mps_gen.set_current_seed(seed);
584
+ }
585
+ }
586
+
587
+ // When the global flag `allow_tf32` is set to true, cuBLAS handles are
588
+ // automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
589
+ // For some operators, such as addmv, TF32 offers no performance improvement
590
+ // but causes precision loss. To help this case, this class implements
591
+ // a RAII guard that can be used to quickly disable TF32 within its scope.
592
+ //
593
+ // Usage:
594
+ // NoTF32Guard disable_tf32;
595
+ struct TORCH_API NoTF32Guard {
596
+ NoTF32Guard();
597
+ ~NoTF32Guard();
598
+ static bool should_disable_tf32();
599
+
600
+ private:
601
+ bool changed = false;
602
+ };
603
+
604
+ struct TORCH_API ROCmBackwardPassGuard {
605
+ ROCmBackwardPassGuard();
606
+ ~ROCmBackwardPassGuard();
607
+ static bool is_backward_pass();
608
+ };
609
+
610
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/Device.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #pragma once
2
+ #include <c10/core/Device.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/DeviceAccelerator.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/DeviceType.h>
4
+ #include <c10/macros/Macros.h>
5
+
6
+ #include <ATen/detail/MTIAHooksInterface.h>
7
+ #include <optional>
8
+
9
+ // This file defines the top level Accelerator concept for PyTorch.
10
+ // A device is an accelerator per the definition here if:
11
+ // - It is mutually exclusive with all other accelerators
12
+ // - It performs asynchronous compute via a Stream/Event system
13
+ // - It provides a set of common APIs as defined by AcceleratorHooksInterface
14
+ //
15
+ // As of today, accelerator devices are (in no particular order):
16
+ // CUDA, MTIA, XPU, HIP, MPS, PrivateUse1
17
+
18
+ namespace at {
19
+
20
+ // Ensures that only one accelerator is available (at
21
+ // compile time if possible) and return it.
22
+ // When checked is true, the returned optional always has a value.
23
+ TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false);
24
+
25
+ TORCH_API bool isAccelerator(c10::DeviceType d);
26
+
27
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/DeprecatedTypeProperties.h>
4
+ #include <c10/macros/Macros.h>
5
+ #include <c10/util/Exception.h>
6
+ #include <c10/util/Half.h>
7
+ #include <c10/util/Metaprogramming.h>
8
+ #include <c10/util/complex.h>
9
+ #include <c10/util/string_view.h>
10
+
11
+ #ifdef __CUDACC__
12
+ #include <cuda.h> // For CUDA_VERSION
13
+ #endif
14
+
15
+ #ifdef TEMPLATE_SELECTIVE_BUILD
16
+ #include <ATen/selected_mobile_ops.h>
17
+ #else
18
+ namespace at {
19
+ /**
20
+ * The method should_include_kernel_dtype() returns true/false
21
+ * based on whether the switching code for a specific dtype should be
22
+ * included based on build time constants generated from tracing model
23
+ * execution. This method will be implemented via code-generation and
24
+ * included in this file when code-gen is ready.
25
+ */
26
+ inline constexpr bool should_include_kernel_dtype(
27
+ const char* /*kernel_tag_str*/,
28
+ at::ScalarType /*scalar_type*/
29
+ ) {
30
+ return true;
31
+ }
32
+ } // namespace at
33
+ #endif
34
+
35
+ /**
36
+ * In the Facebook internal build (using BUCK), this macro is enabled by
37
+ * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
38
+ * binary.
39
+ */
40
+ #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
41
+ namespace at {
42
+ namespace detail {
43
+ TORCH_API void record_kernel_function_dtype(std::string name);
44
+ }
45
+ } // namespace at
46
+
47
+ #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
48
+ at::detail::record_kernel_function_dtype( \
49
+ std::string(NAME) + "$" + toString(enum_type));
50
+ #else
51
+ #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
52
+ #endif
53
+
54
+ #define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \
55
+ do { \
56
+ if constexpr (!at::should_include_kernel_dtype( \
57
+ at_dispatch_name, enum_type)) { \
58
+ AT_ERROR( \
59
+ "dtype '", \
60
+ toString(enum_type), \
61
+ "' not selected for kernel tag ", \
62
+ at_dispatch_name); \
63
+ } \
64
+ } while (0)
65
+
66
+ #define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
67
+ case enum_type: { \
68
+ AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
69
+ using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
70
+ return __VA_ARGS__(); \
71
+ }
72
+
73
+ #define AT_DISPATCH_CASE(enum_type, ...) \
74
+ AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
75
+
76
+ #define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \
77
+ case enum_type: { \
78
+ AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
79
+ using scalar_t = scalar_type; \
80
+ using underlying_t C10_UNUSED = typename scalar_t::underlying; \
81
+ const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
82
+ const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
83
+ return __VA_ARGS__(); \
84
+ }
85
+
86
+ #define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
87
+ enum_type, scalar_type, bitwidth, qmin, qmax, ...) \
88
+ case enum_type: { \
89
+ AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
90
+ using scalar_t = scalar_type; \
91
+ using underlying_t C10_UNUSED = typename scalar_t::underlying; \
92
+ const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
93
+ const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
94
+ C10_UNUSED int bit_width = bitwidth; \
95
+ C10_UNUSED int64_t quant_min = qmin; \
96
+ C10_UNUSED int64_t quant_max = qmax; \
97
+ return __VA_ARGS__(); \
98
+ }
99
+
100
+ namespace detail {
101
+
102
+ inline at::ScalarType scalar_type(at::ScalarType s) {
103
+ return s;
104
+ }
105
+
106
+ C10_DEPRECATED_MESSAGE(
107
+ "passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
108
+ "pass an at::ScalarType instead")
109
+ inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
110
+ return t.scalarType();
111
+ }
112
+
113
+ C10_DEPRECATED_MESSAGE(
114
+ "AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
115
+ "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
116
+ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
117
+
118
+ C10_DEPRECATED_MESSAGE(
119
+ "AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
120
+ "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
121
+ "instead")
122
+ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
123
+
124
+ } // namespace detail
125
+
126
+ // The AT_DISPATCH_* family of macros provides the ability to
127
+ // conveniently generate specializations of a kernel over all of the
128
+ // dtypes we care about in PyTorch. We call it "dispatch" because
129
+ // we are "dispatching" to the correct, dtype-specific kernel.
130
+ //
131
+ // A standard usage looks like:
132
+ //
133
+ // AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
134
+ // // Your code here, with 'scalar_t' now defined to
135
+ // // be the dtype in question
136
+ // });
137
+ //
138
+ // There are many variations of this macro, so it's important to
139
+ // understand exactly /which/ dtypes you want to get instantiated, as
140
+ // well as what the "default" set is.
141
+ //
142
+ // The default set of dtypes that are instantiated (e.g., by
143
+ // AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
144
+ // and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
145
+ // but NOT booleans (bool), half-precision floats (Half) or
146
+ // complex number (c10::complex<float>, c10::complex<double>).
147
+ // This "cut" is somewhat historical (the default types are the
148
+ // ones that TH historically supported), but it also reflects the
149
+ // fact that the non-default types are "poorly" behaved (booleans
150
+ // are NOT integers mod 2, half precision operations ~essentially
151
+ // don't exist on CPU, complex numbers are an experimental application).
152
+ //
153
+ // Here are the questions you should generally ask to decide which
154
+ // dispatch you want:
155
+ //
156
+ // 1. Is this an integral or floating point specific operation?
157
+ // (If so, you'll want one of the FLOATING or INTEGRAL macros.)
158
+ //
159
+ // 2. Should half be supported? (If you're on CPU, the answer is almost
160
+ // definitely no. If you do want support, use one of the AND_HALF
161
+ // macros)
162
+ //
163
+ // Much rarer situations:
164
+ //
165
+ // 3. Should bool be supported? (You often have to write your kernel
166
+ // differently if arithmetic operations are involved.) If so,
167
+ // Use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool
168
+ //
169
+ // 4. Should complex be supported? The answer is almost always no,
170
+ // unless you are working on "generic" code that should work on
171
+ // all dtypes.
172
+ //
173
+ // Parameters:
174
+ // -----------
175
+ //
176
+ // 1. The NAME argument is a "tag" that is used to trace and then
177
+ // conditionally compile fragments of the case statements such
178
+ // that the kernel functions are specialized only for the dtypes
179
+ // that are needed. The NAME parameter *must* be a build time
180
+ // const char* (can't be std::string, etc...)
181
+ //
182
+ // Please ensure that the NAME is unique for every implementation
183
+ // or you run the risk of over-including code for the kernel
184
+ // functions. There is no risk of missing out on any code, so
185
+ // it's mostly a risk of a Type-2 error, and not a Type-1 error.
186
+ //
187
+ // Switch-like syntax:
188
+ // -------------------
189
+ // There is also a switch-case like syntax which is useful if a kernel
190
+ // needs to be specialized for particular scalar types
191
+ //
192
+ // AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
193
+ // AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
194
+ // op_integral<scalar_t>(iter);
195
+ // })
196
+ // AT_DISPATCH_CASE_FLOATING_TYPES([&] {
197
+ // op_floating<scalar_t>(iter);
198
+ // })
199
+ // AT_DISPATCH_CASE(kBool, [&] {
200
+ // op_bool(iter);
201
+ // })
202
+ // );
203
+ //
204
+ // For each AT_DISPATCH_FOO macro, there is a corresponding
205
+ // AT_DISPATCH_CASE_FOO macro which can be used inside of an
206
+ // AT_DISPATCH_SWITCH block.
207
+
208
+ // NB: the the_type variable is not used, but we have kept it for
209
+ // backwards compatibility. It's probably not used by anyone though;
210
+ // but we're just being safe (and it doesn't hurt.) Note we must
211
+ // use it to shut up warnings about unused store.
212
+
213
+ #define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
214
+ [&] { \
215
+ const auto& the_type = TYPE; \
216
+ constexpr const char* at_dispatch_name = NAME; \
217
+ /* don't use TYPE again in case it is an expensive or side-effect op */ \
218
+ at::ScalarType _st = ::detail::scalar_type(the_type); \
219
+ RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
220
+ switch (_st) { \
221
+ __VA_ARGS__ \
222
+ default: \
223
+ AT_ERROR( \
224
+ '"', \
225
+ at_dispatch_name, \
226
+ "\" not implemented for '", \
227
+ toString(_st), \
228
+ "'"); \
229
+ } \
230
+ }()
231
+
232
+ #define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
233
+ AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
234
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
235
+
236
+ #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
237
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
238
+
239
+ #define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
240
+ AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
241
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
242
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
243
+
244
+ #define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
245
+ AT_DISPATCH_SWITCH( \
246
+ TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
247
+
248
+ #define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...) \
249
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
250
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
251
+
252
+ #define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
253
+ AT_DISPATCH_SWITCH( \
254
+ TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))
255
+
256
+ #define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
257
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
258
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
259
+
260
+ #define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
261
+ AT_DISPATCH_SWITCH( \
262
+ TYPE, \
263
+ NAME, \
264
+ AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))
265
+
266
+ #define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
267
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
268
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
269
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
270
+
271
+ #define AT_DISPATCH_FLOATING_TYPES_AND2( \
272
+ SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
273
+ AT_DISPATCH_SWITCH( \
274
+ TYPE, \
275
+ NAME, \
276
+ AT_DISPATCH_CASE_FLOATING_TYPES_AND2( \
277
+ SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
278
+
279
+ #define AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
280
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
281
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
282
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
283
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
284
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
285
+
286
+ #define AT_DISPATCH_FLOATING_TYPES_AND3( \
287
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
288
+ AT_DISPATCH_SWITCH( \
289
+ TYPE, \
290
+ NAME, \
291
+ AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
292
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
293
+
294
+ #define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
295
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
296
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
297
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
298
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
299
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
300
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
301
+
302
+ #define AT_DISPATCH_FLOATING_TYPES_AND4( \
303
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
304
+ AT_DISPATCH_SWITCH( \
305
+ TYPE, \
306
+ NAME, \
307
+ AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
308
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
309
+
310
+ #define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
311
+ AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
312
+ AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
313
+
314
+ #define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
315
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))
316
+
317
+ #define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
318
+ AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \
319
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
320
+
321
+ #define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
322
+ AT_DISPATCH_SWITCH( \
323
+ TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))
324
+
325
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
326
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
327
+ AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
328
+
329
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
330
+ AT_DISPATCH_SWITCH( \
331
+ TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))
332
+
333
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
334
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
335
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
336
+
337
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \
338
+ SCALARTYPE, TYPE, NAME, ...) \
339
+ AT_DISPATCH_SWITCH( \
340
+ TYPE, \
341
+ NAME, \
342
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \
343
+ SCALARTYPE, __VA_ARGS__))
344
+
345
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
346
+ SCALARTYPE1, SCALARTYPE2, ...) \
347
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
348
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
349
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
350
+
351
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \
352
+ SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
353
+ AT_DISPATCH_SWITCH( \
354
+ TYPE, \
355
+ NAME, \
356
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
357
+ SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
358
+
359
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
360
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
361
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
362
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
363
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
364
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
365
+
366
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \
367
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
368
+ AT_DISPATCH_SWITCH( \
369
+ TYPE, \
370
+ NAME, \
371
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
372
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
373
+
374
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
375
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
376
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
377
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
378
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
379
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
380
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
381
+
382
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
383
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
384
+ AT_DISPATCH_SWITCH( \
385
+ TYPE, \
386
+ NAME, \
387
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
388
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
389
+
390
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
391
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
392
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
393
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
394
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
395
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
396
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
397
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
398
+
399
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5( \
400
+ SCALARTYPE1, \
401
+ SCALARTYPE2, \
402
+ SCALARTYPE3, \
403
+ SCALARTYPE4, \
404
+ SCALARTYPE5, \
405
+ TYPE, \
406
+ NAME, \
407
+ ...) \
408
+ AT_DISPATCH_SWITCH( \
409
+ TYPE, \
410
+ NAME, \
411
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
412
+ SCALARTYPE1, \
413
+ SCALARTYPE2, \
414
+ SCALARTYPE3, \
415
+ SCALARTYPE4, \
416
+ SCALARTYPE5, \
417
+ __VA_ARGS__))
418
+
419
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
420
+ SCALARTYPE1, \
421
+ SCALARTYPE2, \
422
+ SCALARTYPE3, \
423
+ SCALARTYPE4, \
424
+ SCALARTYPE5, \
425
+ SCALARTYPE6, \
426
+ ...) \
427
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
428
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
429
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
430
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
431
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
432
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
433
+ AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
434
+
435
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \
436
+ SCALARTYPE1, \
437
+ SCALARTYPE2, \
438
+ SCALARTYPE3, \
439
+ SCALARTYPE4, \
440
+ SCALARTYPE5, \
441
+ SCALARTYPE6, \
442
+ TYPE, \
443
+ NAME, \
444
+ ...) \
445
+ AT_DISPATCH_SWITCH( \
446
+ TYPE, \
447
+ NAME, \
448
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
449
+ SCALARTYPE1, \
450
+ SCALARTYPE2, \
451
+ SCALARTYPE3, \
452
+ SCALARTYPE4, \
453
+ SCALARTYPE5, \
454
+ SCALARTYPE6, \
455
+ __VA_ARGS__))
456
+
457
+ #define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
458
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
459
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
460
+ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
461
+ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
462
+ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)
463
+
464
+ #define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
465
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
466
+
467
+ #define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
468
+ AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
469
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
470
+
471
+ #define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
472
+ AT_DISPATCH_SWITCH( \
473
+ TYPE, \
474
+ NAME, \
475
+ AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
476
+
477
+ #define AT_DISPATCH_CASE_ALL_TYPES(...) \
478
+ AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
479
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)
480
+
481
+ #define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
482
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
483
+
484
+ #define AT_DISPATCH_CASE_QINT_TYPES(...) \
485
+ AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
486
+ AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \
487
+ AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)
488
+
489
+ #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
490
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))
491
+
492
+ #define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
493
+ AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__) \
494
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
495
+
496
+ #define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
497
+ AT_DISPATCH_SWITCH( \
498
+ TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))
499
+
500
+ #define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \
501
+ AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
502
+ AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)
503
+
504
+ #define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
505
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))
506
+
507
+ #define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...) \
508
+ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
509
+ at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
510
+ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
511
+ at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
512
+ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
513
+ at::kQInt32, \
514
+ at::qint32, \
515
+ CHAR_BIT * sizeof(int), \
516
+ INT_MIN, \
517
+ INT_MAX, \
518
+ __VA_ARGS__) \
519
+ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
520
+ at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__) \
521
+ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
522
+ at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)
523
+
524
+ #define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
525
+ AT_DISPATCH_SWITCH( \
526
+ TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
527
+
528
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
529
+ AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
530
+ AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
531
+
532
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
533
+ AT_DISPATCH_SWITCH( \
534
+ TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))
535
+
536
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
537
+ AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
538
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
539
+
540
+ #define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
541
+ AT_DISPATCH_SWITCH( \
542
+ TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
543
+
544
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
545
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
546
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
547
+
548
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
549
+ AT_DISPATCH_SWITCH( \
550
+ TYPE, \
551
+ NAME, \
552
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))
553
+
554
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
555
+ AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
556
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
557
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
558
+
559
+ #define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
560
+ AT_DISPATCH_SWITCH( \
561
+ TYPE, \
562
+ NAME, \
563
+ AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
564
+
565
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
566
+ SCALARTYPE1, SCALARTYPE2, ...) \
567
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
568
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
569
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
570
+
571
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
572
+ SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
573
+ AT_DISPATCH_SWITCH( \
574
+ TYPE, \
575
+ NAME, \
576
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
577
+ SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
578
+
579
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND3( \
580
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
581
+ AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
582
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
583
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
584
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
585
+
586
+ #define AT_DISPATCH_ALL_TYPES_AND3( \
587
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
588
+ AT_DISPATCH_SWITCH( \
589
+ TYPE, \
590
+ NAME, \
591
+ AT_DISPATCH_CASE_ALL_TYPES_AND3( \
592
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
593
+
594
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
595
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
596
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
597
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
598
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
599
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
600
+
601
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
602
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
603
+ AT_DISPATCH_SWITCH( \
604
+ TYPE, \
605
+ NAME, \
606
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
607
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
608
+
609
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
610
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
611
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
612
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
613
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
614
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
615
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
616
+
617
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
618
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
619
+ AT_DISPATCH_SWITCH( \
620
+ TYPE, \
621
+ NAME, \
622
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
623
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
624
+
625
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
626
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
627
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
628
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
629
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
630
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
631
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
632
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
633
+
634
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
635
+ SCALARTYPE1, \
636
+ SCALARTYPE2, \
637
+ SCALARTYPE3, \
638
+ SCALARTYPE4, \
639
+ SCALARTYPE5, \
640
+ TYPE, \
641
+ NAME, \
642
+ ...) \
643
+ AT_DISPATCH_SWITCH( \
644
+ TYPE, \
645
+ NAME, \
646
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
647
+ SCALARTYPE1, \
648
+ SCALARTYPE2, \
649
+ SCALARTYPE3, \
650
+ SCALARTYPE4, \
651
+ SCALARTYPE5, \
652
+ __VA_ARGS__))
653
+
654
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
655
+ SCALARTYPE1, \
656
+ SCALARTYPE2, \
657
+ SCALARTYPE3, \
658
+ SCALARTYPE4, \
659
+ SCALARTYPE5, \
660
+ SCALARTYPE6, \
661
+ ...) \
662
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
663
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
664
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
665
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
666
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
667
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
668
+ AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
669
+
670
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
671
+ SCALARTYPE1, \
672
+ SCALARTYPE2, \
673
+ SCALARTYPE3, \
674
+ SCALARTYPE4, \
675
+ SCALARTYPE5, \
676
+ SCALARTYPE6, \
677
+ TYPE, \
678
+ NAME, \
679
+ ...) \
680
+ AT_DISPATCH_SWITCH( \
681
+ TYPE, \
682
+ NAME, \
683
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
684
+ SCALARTYPE1, \
685
+ SCALARTYPE2, \
686
+ SCALARTYPE3, \
687
+ SCALARTYPE4, \
688
+ SCALARTYPE5, \
689
+ SCALARTYPE6, \
690
+ __VA_ARGS__))
691
+
692
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
693
+ SCALARTYPE1, \
694
+ SCALARTYPE2, \
695
+ SCALARTYPE3, \
696
+ SCALARTYPE4, \
697
+ SCALARTYPE5, \
698
+ SCALARTYPE6, \
699
+ SCALARTYPE7, \
700
+ ...) \
701
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
702
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
703
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
704
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
705
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
706
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
707
+ AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
708
+ AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__)
709
+
710
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7( \
711
+ SCALARTYPE1, \
712
+ SCALARTYPE2, \
713
+ SCALARTYPE3, \
714
+ SCALARTYPE4, \
715
+ SCALARTYPE5, \
716
+ SCALARTYPE6, \
717
+ SCALARTYPE7, \
718
+ TYPE, \
719
+ NAME, \
720
+ ...) \
721
+ AT_DISPATCH_SWITCH( \
722
+ TYPE, \
723
+ NAME, \
724
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
725
+ SCALARTYPE1, \
726
+ SCALARTYPE2, \
727
+ SCALARTYPE3, \
728
+ SCALARTYPE4, \
729
+ SCALARTYPE5, \
730
+ SCALARTYPE6, \
731
+ SCALARTYPE7, \
732
+ __VA_ARGS__))
733
+
734
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
735
+ SCALARTYPE1, \
736
+ SCALARTYPE2, \
737
+ SCALARTYPE3, \
738
+ SCALARTYPE4, \
739
+ SCALARTYPE5, \
740
+ SCALARTYPE6, \
741
+ SCALARTYPE7, \
742
+ SCALARTYPE8, \
743
+ ...) \
744
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
745
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
746
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
747
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
748
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
749
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
750
+ AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
751
+ AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) \
752
+ AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__)
753
+
754
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
755
+ SCALARTYPE1, \
756
+ SCALARTYPE2, \
757
+ SCALARTYPE3, \
758
+ SCALARTYPE4, \
759
+ SCALARTYPE5, \
760
+ SCALARTYPE6, \
761
+ SCALARTYPE7, \
762
+ SCALARTYPE8, \
763
+ TYPE, \
764
+ NAME, \
765
+ ...) \
766
+ AT_DISPATCH_SWITCH( \
767
+ TYPE, \
768
+ NAME, \
769
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
770
+ SCALARTYPE1, \
771
+ SCALARTYPE2, \
772
+ SCALARTYPE3, \
773
+ SCALARTYPE4, \
774
+ SCALARTYPE5, \
775
+ SCALARTYPE6, \
776
+ SCALARTYPE7, \
777
+ SCALARTYPE8, \
778
+ __VA_ARGS__))
779
+
780
+ #define AT_DISPATCH_CASE_BIT_TYPES(...) \
781
+ AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \
782
+ AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \
783
+ AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \
784
+ AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__) \
785
+ AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__)
786
+
787
+ #define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \
788
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__))
789
+
790
+ #define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
791
+ AT_DISPATCH_SWITCH( \
792
+ TYPE, \
793
+ NAME, \
794
+ AT_PRIVATE_CASE_TYPE_USING_HINT( \
795
+ at::ScalarType::Int, index_t, __VA_ARGS__) \
796
+ AT_PRIVATE_CASE_TYPE_USING_HINT( \
797
+ at::ScalarType::Long, index_t, __VA_ARGS__))
798
+
799
+ // ----------------------------------------------------------------------------
800
+ // DEPRECATED MACROS, DON'T USE THESE
801
+ // ----------------------------------------------------------------------------
802
+
803
+ #define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
804
+ detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
805
+ AT_DISPATCH_SWITCH( \
806
+ TYPE, \
807
+ NAME, \
808
+ AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__))
.venv/lib/python3.11/site-packages/torch/include/ATen/EmptyTensor.h ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/TensorBase.h>
3
+
4
+ namespace at::detail {
5
+
6
+ inline void check_size_nonnegative(ArrayRef<int64_t> size) {
7
+ for (const auto& x : size) {
8
+ TORCH_CHECK(
9
+ x >= 0,
10
+ "Trying to create tensor with negative dimension ",
11
+ x,
12
+ ": ",
13
+ size);
14
+ }
15
+ }
16
+
17
+ inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
18
+ for (const auto& x : size) {
19
+ TORCH_CHECK(
20
+ x.expect_size(__FILE__, __LINE__),
21
+ "Trying to create tensor with negative dimension ",
22
+ x,
23
+ ": ",
24
+ size);
25
+ }
26
+ }
27
+
28
+ TORCH_API size_t computeStorageNbytesContiguous(
29
+ IntArrayRef sizes,
30
+ size_t itemsize,
31
+ size_t storage_offset = 0);
32
+ TORCH_API SymInt computeStorageNbytesContiguous(
33
+ SymIntArrayRef sizes,
34
+ const SymInt& itemsize,
35
+ const SymInt& storage_offset = 0);
36
+ TORCH_API size_t computeStorageNbytes(
37
+ IntArrayRef sizes,
38
+ IntArrayRef strides,
39
+ size_t itemsize,
40
+ size_t storage_offset = 0);
41
+ TORCH_API SymInt computeStorageNbytes(
42
+ SymIntArrayRef sizes,
43
+ SymIntArrayRef strides,
44
+ const SymInt& itemsize,
45
+ const SymInt& storage_offset = 0);
46
+
47
+ TORCH_API TensorBase empty_generic(
48
+ IntArrayRef size,
49
+ c10::Allocator* allocator,
50
+ c10::DispatchKeySet ks,
51
+ ScalarType scalar_type,
52
+ std::optional<c10::MemoryFormat> memory_format_opt);
53
+
54
+ TORCH_API TensorBase empty_generic_symint(
55
+ SymIntArrayRef size,
56
+ c10::Allocator* allocator,
57
+ c10::DispatchKeySet ks,
58
+ ScalarType scalar_type,
59
+ std::optional<c10::MemoryFormat> memory_format_opt);
60
+
61
+ TORCH_API TensorBase empty_strided_generic(
62
+ IntArrayRef size,
63
+ IntArrayRef stride,
64
+ c10::Allocator* allocator,
65
+ c10::DispatchKeySet ks,
66
+ ScalarType scalar_type);
67
+
68
+ TORCH_API TensorBase empty_strided_symint_generic(
69
+ SymIntArrayRef size,
70
+ SymIntArrayRef stride,
71
+ c10::Allocator* allocator,
72
+ c10::DispatchKeySet ks,
73
+ ScalarType scalar_type);
74
+
75
+ TORCH_API TensorBase empty_cpu(
76
+ IntArrayRef size,
77
+ ScalarType dtype,
78
+ bool pin_memory = false,
79
+ std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
80
+
81
+ TORCH_API TensorBase empty_cpu(
82
+ IntArrayRef size,
83
+ std::optional<ScalarType> dtype_opt,
84
+ std::optional<Layout> layout_opt,
85
+ std::optional<Device> device_opt,
86
+ std::optional<bool> pin_memory_opt,
87
+ std::optional<c10::MemoryFormat> memory_format_opt);
88
+
89
+ TORCH_API TensorBase empty_cpu(IntArrayRef size, const TensorOptions& options);
90
+
91
+ TORCH_API TensorBase empty_strided_cpu(
92
+ IntArrayRef size,
93
+ IntArrayRef stride,
94
+ ScalarType dtype,
95
+ bool pin_memory = false);
96
+
97
+ TORCH_API TensorBase empty_strided_cpu(
98
+ IntArrayRef size,
99
+ IntArrayRef stride,
100
+ std::optional<ScalarType> dtype_opt,
101
+ std::optional<Layout> layout_opt,
102
+ std::optional<Device> device_opt,
103
+ std::optional<bool> pin_memory_opt);
104
+
105
+ TORCH_API TensorBase empty_strided_cpu(
106
+ IntArrayRef size,
107
+ IntArrayRef stride,
108
+ const TensorOptions& options);
109
+
110
+ TORCH_API TensorBase empty_meta(
111
+ IntArrayRef size,
112
+ ScalarType dtype,
113
+ std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
114
+
115
+ TORCH_API TensorBase empty_meta(
116
+ IntArrayRef size,
117
+ std::optional<ScalarType> dtype_opt,
118
+ std::optional<Layout> layout_opt,
119
+ std::optional<Device> device_opt,
120
+ std::optional<bool> pin_memory_opt,
121
+ std::optional<c10::MemoryFormat> memory_format_opt);
122
+
123
+ TORCH_API TensorBase empty_symint_meta(
124
+ SymIntArrayRef size,
125
+ std::optional<ScalarType> dtype_opt,
126
+ std::optional<Layout> layout_opt,
127
+ std::optional<Device> device_opt,
128
+ std::optional<bool> pin_memory_opt,
129
+ std::optional<c10::MemoryFormat> memory_format_opt);
130
+
131
+ TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options);
132
+
133
+ TORCH_API TensorBase
134
+ empty_strided_meta(IntArrayRef size, IntArrayRef stride, ScalarType dtype);
135
+
136
+ TORCH_API TensorBase empty_strided_meta(
137
+ IntArrayRef size,
138
+ IntArrayRef stride,
139
+ std::optional<ScalarType> dtype_opt,
140
+ std::optional<Layout> layout_opt,
141
+ std::optional<Device> device_opt,
142
+ std::optional<bool> pin_memory_opt);
143
+
144
+ TORCH_API TensorBase empty_strided_meta(
145
+ IntArrayRef size,
146
+ IntArrayRef stride,
147
+ const TensorOptions& options);
148
+
149
+ TORCH_API TensorBase empty_strided_symint_meta(
150
+ SymIntArrayRef size,
151
+ SymIntArrayRef stride,
152
+ ScalarType dtype);
153
+
154
+ TORCH_API TensorBase empty_strided_symint_meta(
155
+ SymIntArrayRef size,
156
+ SymIntArrayRef stride,
157
+ std::optional<ScalarType> dtype_opt,
158
+ std::optional<Layout> layout_opt,
159
+ std::optional<Device> device_opt);
160
+
161
+ TORCH_API TensorBase empty_strided_symint_meta(
162
+ SymIntArrayRef size,
163
+ SymIntArrayRef stride,
164
+ const TensorOptions& options);
165
+
166
+ } // namespace at::detail
.venv/lib/python3.11/site-packages/torch/include/ATen/ExpandBase.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/TensorBase.h>
2
+
3
+ // Broadcasting utilities for working with TensorBase
4
+ namespace at {
5
+ namespace internal {
6
+ TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size);
7
+ } // namespace internal
8
+
9
+ inline c10::MaybeOwned<TensorBase> expand_size(
10
+ const TensorBase& self,
11
+ IntArrayRef size) {
12
+ if (size.equals(self.sizes())) {
13
+ return c10::MaybeOwned<TensorBase>::borrowed(self);
14
+ }
15
+ return c10::MaybeOwned<TensorBase>::owned(
16
+ at::internal::expand_slow_path(self, size));
17
+ }
18
+ c10::MaybeOwned<TensorBase> expand_size(TensorBase&& self, IntArrayRef size) =
19
+ delete;
20
+
21
+ inline c10::MaybeOwned<TensorBase> expand_inplace(
22
+ const TensorBase& tensor,
23
+ const TensorBase& to_expand) {
24
+ return expand_size(to_expand, tensor.sizes());
25
+ }
26
+ c10::MaybeOwned<TensorBase> expand_inplace(
27
+ const TensorBase& tensor,
28
+ TensorBase&& to_expand) = delete;
29
+
30
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/FuncTorchTLS.h ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/macros/Macros.h>
4
+ #include <memory>
5
+
6
+ namespace at::functorch {
7
+
8
+ // NOTE [functorch TLS in pytorch/pytorch]
9
+ //
10
+ // functorch lives out-of-tree. However, it has some TLS that needs to be
11
+ // propagated. The solution for that is we store a pointer to the TLS
12
+ // inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to
13
+ // include whatever functorch needs.
14
+ //
15
+ // We need to store a pointer due to the indirection:
16
+ // inside functorch, we will create a subclass of FunctorchTLSBase called
17
+ // FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack.
18
+ // FuncTorchTLSBase doesn't have any metadata because it hasn't been defined
19
+ // yet.
20
+ //
21
+ // Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside
22
+ // functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*.
23
+ // We can't directly pass around FunctorchTLSBase (without a pointer) because
24
+ // FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having
25
+ // more elements.
26
+ struct TORCH_API FuncTorchTLSBase {
27
+ virtual ~FuncTorchTLSBase() = default;
28
+ virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
29
+
30
+ virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
31
+ virtual void checkSupportsCppAutogradFunction() const = 0;
32
+ virtual void checkSupportsInplaceRequiresGrad() const = 0;
33
+ virtual void checkSupportsRetainGrad() const = 0;
34
+ };
35
+
36
+ // returns deepcopy of the functorch tls
37
+ TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS();
38
+
39
+ // sets the functorch tls. always does a deep copy.
40
+ TORCH_API void setFuncTorchTLS(
41
+ const std::shared_ptr<const FuncTorchTLSBase>& state);
42
+
43
+ // get a mutable reference to the functorch tls
44
+ TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
45
+
46
+ } // namespace at::functorch
.venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalStorageImpl.h ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Tensor.h>
4
+
5
+ #include <utility>
6
+
7
+ namespace at::functionalization {
8
+
9
+ // See Note [Functionalization Pass In Core]
10
+
11
+ // ViewMeta is a class used by the functionalization pass to navigate between
12
+ // a base tensor and a view tensor.
13
+ // For example, if I call `b = a.view1(...)`
14
+ // the functionalization pass will generate and store a ViewMeta on b that looks
15
+ // like:
16
+ //
17
+ // ViewMeta(
18
+ // [<captures>](const Tensor& base, int64_t mutated_view_idx) {
19
+ // return base.view1(...);
20
+ // },
21
+ // [<captures>](const at::Tensor& base, const at::Tensor& mutated_view,
22
+ // int64_t mutated_view_idx) -> at::Tensor {
23
+ // return at::functionalization::impl::view1_inverse(base, mutated_view,
24
+ // ...);
25
+ // }
26
+ //
27
+ // The forward_fn lambda describes how to replay view1 on a tensor.
28
+ //
29
+ // The reverse_fn lambda describes how, given a tensor that is already a view,
30
+ // how to get the corresponding base tensor. See Note [Functionalization Pass:
31
+ // View Inverses] for details.
32
+ struct ViewMeta {
33
+ ViewMeta(
34
+ std::function<Tensor(const Tensor&, int64_t)> forward,
35
+ std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
36
+ bool has_symbolic_inputs,
37
+ bool is_multi_output = false,
38
+ bool is_as_strided = false,
39
+ int64_t out_idx = 0)
40
+ : forward_fn(std::move(forward)),
41
+ reverse_fn(std::move(reverse)),
42
+ out_index(out_idx),
43
+ is_multi_output(is_multi_output),
44
+ is_as_strided(is_as_strided),
45
+ has_symbolic_inputs(has_symbolic_inputs) {}
46
+
47
+ std::function<Tensor(const Tensor&, int64_t)> forward_fn;
48
+ std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
49
+ // See Note [out_idx in ViewMeta]
50
+ int64_t out_index;
51
+
52
+ // Tells us if this is a multi-output view
53
+ bool is_multi_output;
54
+
55
+ bool is_as_strided;
56
+
57
+ // Tells us if this view operation has any symbolic inputs
58
+ bool has_symbolic_inputs;
59
+
60
+ // Returns a copy of the current ViewMeta, if out_idx matches the current
61
+ // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
62
+ // functions, but a new out index.
63
+ ViewMeta to_out_idx(int64_t out_idx);
64
+ };
65
+
66
+ // FunctionalStorageImpl is a subclass of StorageImpl used by the
67
+ // functionalization pass. It has no underlying data (similar to meta storage).
68
+ // It also knows how to reflect mutations to tensors in the absence of a valid
69
+ // data pointer.
70
+ //
71
+ // A storage represents the state shared by (potentially multiple) views of the
72
+ // same tensor. For example, in the following code:
73
+ //
74
+ // b = a.view1(...)
75
+ // c = b.view2(...)
76
+ // b.add_(1)
77
+ // --> storage.add_update(b, {view1_meta})
78
+ //
79
+ // The call to add_(1) will result in a call to alias.add_update(b,
80
+ // {view1_meta}), queueing up the mutation from b onto the alias. Later, suppose
81
+ // c is used in an expression (e.g. you try to print c, or pass it to an
82
+ // operator). Doing so will involve "syncing" c. First we apply any pending
83
+ // updates to the alias, and then we regenerate c by replaying its views off of
84
+ // the updated alias. E.g:
85
+ //
86
+ // print(str(c))
87
+ // --> c.sync_()
88
+ // --> alias.apply_updates() // after this, the alias will be updated to
89
+ // reflect the mutation to b
90
+ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
91
+ public:
92
+ struct Update {
93
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
94
+ const at::Tensor new_val;
95
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
96
+ const std::vector<ViewMeta> view_metas;
97
+ };
98
+
99
+ explicit FunctionalStorageImpl(const Tensor& value);
100
+
101
+ void add_update(
102
+ const Tensor& updated_val,
103
+ const std::vector<ViewMeta>& view_metas);
104
+ bool apply_updates();
105
+ const Tensor& base() {
106
+ return base_;
107
+ }
108
+ size_t generation() const {
109
+ return generation_;
110
+ }
111
+ void freeze() {
112
+ frozen_ = true;
113
+ }
114
+
115
+ c10::SymInt get_storage_size(bool before) {
116
+ if (before) {
117
+ return original_storage_size_;
118
+ } else {
119
+ return curr_storage_size_;
120
+ }
121
+ }
122
+
123
+ ~FunctionalStorageImpl() override = default;
124
+
125
+ void mark_mutation() {
126
+ mutation_counter_++;
127
+ }
128
+ void mark_mutation_during_no_grad_or_inference_mode() {
129
+ mutation_counter_during_no_grad_or_inference_mode_++;
130
+ }
131
+ void mark_mutation_hidden_from_autograd() {
132
+ mutation_counter_hidden_from_autograd_++;
133
+ }
134
+
135
+ bool are_all_mutations_under_no_grad_or_inference_mode() const {
136
+ auto non_autograd_mutations =
137
+ mutation_counter_during_no_grad_or_inference_mode_ +
138
+ mutation_counter_hidden_from_autograd_;
139
+ // The <= is because both counters will technically be incremented, if we
140
+ // perform e.g. a triton kernel mutation under no_grad
141
+ return mutation_counter_ <= non_autograd_mutations;
142
+ }
143
+
144
+ bool are_all_mutations_hidden_from_autograd() const {
145
+ // mutations under no_grad / inference_mode are technically not hidden from
146
+ // autograd - they change the version counter
147
+ return mutation_counter_ <= mutation_counter_hidden_from_autograd_;
148
+ }
149
+
150
+ void mark_inductor_storage_resize(c10::SymInt new_size) {
151
+ inductor_storage_resized_ = true;
152
+ curr_storage_size_ = std::move(new_size);
153
+ }
154
+
155
+ bool was_inductor_storage_resized() {
156
+ return inductor_storage_resized_;
157
+ }
158
+
159
+ private:
160
+ // NB: base_ should always point to a tensor BELOW the current
161
+ // functionalization layer. This is mainly to avoid reference cycles. e.g.
162
+ // given `b = a.view(...)` Both a.storage_ and b.storage_ are a
163
+ // FunctionStorageImpl containing an Walualias, with contains a Tensor
164
+ // `base_`. In this case (where a and b are FunctionalTensorWrapper's), base_
165
+ // should point not to a, but to a's unwrapped value, a.value_` See Note
166
+ // [Functionalization: Walualias Removal] for a diagram that shows this
167
+ // visually.
168
+ at::Tensor base_;
169
+ std::vector<Update> updates_;
170
+ // generation_ gets incremented every time a mutation is queued onto the
171
+ // alias. It is used to determine if a given tensor is "up to date", or if it
172
+ // needs to be regenerated from the alias.
173
+ size_t generation_ = 0;
174
+ // If frozen, no more mutations are allowed on this storage. Once frozen, a
175
+ // storage cannot be unfrozen.
176
+ bool frozen_ = false;
177
+
178
+ // These mutation counters are bumped on the storage
179
+ // whenever a FunctionalTensorWrapper experiences a mutation.
180
+ // When the mutation is under no_grad, or comes from a triton kernel, we also
181
+ // bump the corresponding during_no_grad or hidden_from_autograd counters. Why
182
+ // do we need to detect these two situations separately from "normal" input
183
+ // mutations? (1) "normal" input mutations can mutate autograd metadata like
184
+ // .grad_fn,
185
+ // in which case they need to be replayed outside of the compiled graph
186
+ // (2) "no_grad" input mutations are generally safe to keep in the graph (and
187
+ // compile),
188
+ // but they bump the tensor's VC, so we need to mark_dirty() on the inputs
189
+ // in torch.compile
190
+ // (3) mutations that are fully hidden from autograd (e.g. from a triton
191
+ // kernel)
192
+ // do not mutate any autograd state, and be fully kept in the graph
193
+ // When we detect that an input was mutated, we need to be able to tell if:
194
+ // (1) all of the mutations were from triton kernels
195
+ // (2) all of the mutations were under no_grad
196
+ uint64_t mutation_counter_during_no_grad_or_inference_mode_ = 0;
197
+ uint64_t mutation_counter_ = 0;
198
+ uint64_t mutation_counter_hidden_from_autograd_ = 0;
199
+
200
+ // Used to tell if:
201
+ // (1) There were any storage resizes on a graph input
202
+ // (2) The original/curr storage size tell us if these resizes result in a nop
203
+ bool inductor_storage_resized_ = false;
204
+ c10::SymInt original_storage_size_;
205
+ c10::SymInt curr_storage_size_;
206
+ };
207
+
208
+ } // namespace at::functionalization
.venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalTensorWrapper.h ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #pragma once
3
+
4
+ #include <ATen/ArrayRef.h>
5
+ #include <ATen/FunctionalStorageImpl.h>
6
+ #include <ATen/core/IListRef.h>
7
+ #include <ATen/core/List.h>
8
+ #include <ATen/core/boxing/BoxedKernel.h>
9
+ #include <ATen/core/boxing/impl/boxing.h>
10
+ #include <ATen/core/dispatch/Dispatcher.h>
11
+
12
+ #include <c10/core/DispatchKey.h>
13
+
14
+ namespace at {
15
+
16
+ // Note [Functionalization Pass In Core]
17
+ // The Functionalization pass is used to remove aliasing from a pytorch program.
18
+ //
19
+ // This is useful for backends that don't support aliasing, like XLA and Vulkan.
20
+ // It's also necessary in order to remove mutation from a program, which is
21
+ // needed in Functorch.
22
+ //
23
+ // Consider this program:
24
+ // a = torch.ones(...)
25
+ // b = a.view(...)
26
+ // b.add_(1)
27
+ //
28
+ // In this program, b is meant to alias with a due to the use of view(). At the
29
+ // end of the program, both a and b are full of 2's. However, backends that
30
+ // don't support aliasing aren't able to correctly implement the view()
31
+ // operator. Instead, they can opt into the Functionalization pass, which will
32
+ // sit between the user and the backend, and provide the necessary aliasing
33
+ // logic.
34
+ //
35
+ // The functionalization pass will turn the above program into a slightly
36
+ // different program that has the same semantics, transparently to the user,
37
+ // that backends like XLA/Vulkan are able to implement a = torch.ones(...) b =
38
+ // a.view_copy(...) # view() replaced with view_copy(). Backends like
39
+ // XLA/Vulkan can implement this! b.add_(1) a.add_(1) # Our functionalization
40
+ // pass machinery knows that a and b are aliased - it applies b's mutation to a
41
+ // too.
42
+ //
43
+ // So, how does the functionalization pass keep track of which tensors are
44
+ // aliased? The pass works by wrapping EVERY tensor in the program inside of a
45
+ // FunctionalTensorWrapper, which knows about its alias'd tensors.
46
+ //
47
+ // See Note [Functionalization: Alias Removal] for details on the aliasing
48
+ // machinery. See Note [Functionalization: Mutation Removal] for details on
49
+ // mutation removal.
50
+ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
51
+ explicit FunctionalTensorWrapper(const Tensor& value);
52
+ // Additional constructor to create a FunctionalTensorWrapper directly from an
53
+ // underlying tensor that was created from a view. For example, the code b =
54
+ // a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
55
+ // view1_meta)
56
+ explicit FunctionalTensorWrapper(
57
+ const Tensor& view_value,
58
+ const FunctionalTensorWrapper* base,
59
+ const functionalization::ViewMeta& meta);
60
+
61
+ // Get the underlying, actual tensor, that doesn't know anything about
62
+ // functionalization.
63
+ const Tensor& value() const {
64
+ return value_;
65
+ };
66
+ // The concept of "level" is only ever important to functorch; it's exposed
67
+ // here as more of a hook for functorch to use.
68
+ int64_t level() const {
69
+ return level_;
70
+ };
71
+ void set_level(int64_t level) {
72
+ level_ = level;
73
+ }
74
+ bool has_metadata_mutation() const {
75
+ return has_metadata_mutation_;
76
+ };
77
+
78
+ void mark_mutation() {
79
+ functional_storage_impl()->mark_mutation();
80
+ }
81
+ // Denotes a mutation that's hidden from autograd,
82
+ // e.g. for the purposes of passing a tensor to a triton kernel
83
+ void mark_mutation_hidden_from_autograd() {
84
+ functional_storage_impl()->mark_mutation_hidden_from_autograd();
85
+ }
86
+ void mark_mutation_during_no_grad_or_inference_mode() {
87
+ functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
88
+ }
89
+ // Are all the mutations happening to the tensor hidden from autograd
90
+ bool are_all_mutations_hidden_from_autograd() const {
91
+ return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
92
+ }
93
+ // Did all mutations happen under no_grad or inference_mode
94
+ // (We also need to ignore mutations fully hidden from autograd here)
95
+ bool are_all_mutations_under_no_grad_or_inference_mode() const {
96
+ return functional_storage_impl()
97
+ ->are_all_mutations_under_no_grad_or_inference_mode();
98
+ }
99
+
100
+ void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
101
+ is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
102
+ }
103
+
104
+ bool is_symbolic() const {
105
+ return is_symbolic_;
106
+ }
107
+
108
+ // Runs the forward_fn of every ViewMeta collected in the current instance
109
+ // to some other base.
110
+ Tensor apply_view_metas(const Tensor& base);
111
+
112
+ // Sync's the underlying tensor with its alias, if it's out of date. This
113
+ // involves two steps: 1) Apply any pending updates/mutations to the alias 2)
114
+ // Replay the views (if any) to regenerate the current tensor off of the
115
+ // updated alias.
116
+ void sync_();
117
+ // Performs step (1) of the sync. This is its own public API because it's
118
+ // needed by view_inplace ops like transpose_. See Note [Functionalization
119
+ // Pass - Inplace View Ops]
120
+ void regenerate_from_base();
121
+ // Performs step (2) of the sync. This is its own public API because it's
122
+ // needed by functorch. functorch wants to make sure that all input tensors to
123
+ // a functionalized program have been properly synced so it can properly
124
+ // propagate mutations to inputs. It can't just call sync_(), because the
125
+ // FunctionalTensorWrapper will look like it has no aliases and sync_ will be
126
+ // a noop. We use the reference count on storage_ to determine if the wrapper
127
+ // is aliased, and by the time functorch is ready to propagate updates to
128
+ // inputs, any intermediate views of the input created by the program will
129
+ // have been deallocated. This function also returns whether or not the base
130
+ // actually had any updates to apply.
131
+ bool apply_updates();
132
+ // Takes the current state of value_ and snapshots it, sending it as a pending
133
+ // update to the alias.
134
+ void commit_update();
135
+ // When any tensor is mutated, the tensor increments its alias's "generation".
136
+ // Separately, each tensor maintains its own "generation" counter, which is
137
+ // used to determine if it's up-to-date with its alias. The act of syncing a
138
+ // tensor will set a tensor's generation equal to its alias's generation.
139
+ bool is_up_to_date() const;
140
+ // Freezes the storage of this tensor, preventing subsequent mutations
141
+ void freeze_storage() const;
142
+ // Every FunctionalTensorWrapper contains a vector<ViewMeta> objects
143
+ // describing the series of view ops that ran to generate the current tensor
144
+ // from the base tensor. This method is used by inplace-view ops like
145
+ // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
146
+ // tensor by replaying the views off of the alias.
147
+ void mutate_view_meta(const at::functionalization::ViewMeta& meta);
148
+
149
+ // Custom implementation of self.set_(src)
150
+ void set__impl(const FunctionalTensorWrapper* other);
151
+
152
+ // Custom implementation of resize_storage_bytes_(self, new_size)
153
+ void storage_resize_(const c10::SymInt& new_size);
154
+
155
+ // Returns whether the current tensor's data was ever mutated
156
+ bool has_data_mutation();
157
+ //
158
+ // Returns whether the current FunctionalTensorWrapper
159
+ // experienced a set_() call.
160
+ bool was_storage_changed() {
161
+ return was_storage_changed_;
162
+ }
163
+
164
+ void set_storage_changed() {
165
+ was_storage_changed_ = true;
166
+ }
167
+
168
+ // A FunctionalTensor is considered a base if its not a view of another
169
+ // tensor.
170
+ bool isBaseTensor() const {
171
+ return view_metas_.empty();
172
+ }
173
+
174
+ c10::SymInt get_storage_size(bool before) {
175
+ return functional_storage_impl()->get_storage_size(before);
176
+ }
177
+
178
+ // Returns whether the FunctionalTensor experienced an
179
+ // untyped_storage().resize_() call
180
+ bool was_inductor_storage_resized() {
181
+ return functional_storage_impl()->was_inductor_storage_resized();
182
+ }
183
+
184
+ // The functionalization pass can be used to remove mutations.
185
+ // It does so by replacing any mutation op with it's corresponding
186
+ // out-of-place op, followed by a call to replace_(). e.g:
187
+ //
188
+ // a.add_(1)
189
+ //
190
+ // will turn into:
191
+ //
192
+ // tmp = a.add(1)
193
+ // a.replace_(tmp)
194
+ //
195
+ // replace_() swaps out the wrapped tensor, value_, with tmp.
196
+ void replace_(const Tensor& other, bool from_lazy_regenerate = false);
197
+
198
+ bool is_multi_output_view() {
199
+ return is_multi_output_view_;
200
+ }
201
+
202
+ // See Note[resize_() in functionalization pass]
203
+ void maybe_replace_storage(const Tensor& other);
204
+
205
+ // Replaces the storage with a new functional storage,
206
+ // and clears the view_metas_ stack.
207
+ // WARNING: Calling this function will sever the aliasing relationship between
208
+ // the current FunctionalTensorWrapper and any of its outstanding aliases.
209
+ // Please only call if you know what you're doing.
210
+ void _unsafe_reset_storage();
211
+
212
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
213
+ const c10::VariableVersion& version_counter,
214
+ bool allow_tensor_metadata_change) const override;
215
+
216
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
217
+ c10::VariableVersion&& version_counter,
218
+ bool allow_tensor_metadata_change) const override;
219
+
220
+ ~FunctionalTensorWrapper() override = default;
221
+
222
+ // FunctionalTensorWrapper overrides all custom size/stride function,
223
+ // so that if the inner tensor has a custom implementation
224
+ // we make sure to call that implementation.
225
+ at::IntArrayRef sizes_custom() const override;
226
+ at::IntArrayRef strides_custom() const override;
227
+ int64_t dim_custom() const override;
228
+ int64_t numel_custom() const override;
229
+ bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
230
+ c10::SymIntArrayRef sym_sizes_custom() const override;
231
+ c10::SymInt sym_size_custom(int64_t d) const override;
232
+ c10::SymIntArrayRef sym_strides_custom() const override;
233
+ c10::SymInt sym_storage_offset_custom() const override;
234
+ c10::Device device_custom() const override;
235
+ c10::Layout layout_impl() const override;
236
+
237
+ private:
238
+ const char* tensorimpl_type_name() const override;
239
+ void set_constructor_metadata();
240
+ functionalization::FunctionalStorageImpl* functional_storage_impl() const;
241
+
242
+ // This is used to re-implement shallow_copy_and_detach for
243
+ // FunctionalTensorWrapper. The implementation is identical, but we just need
244
+ // to return a subclass instead of a plain TensorImpl.
245
+ // TODO: maybe it's possible to arrange for that to happen automatically
246
+ // without an override here?
247
+ template <typename VariableVersion>
248
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
249
+ VariableVersion&& version_counter,
250
+ bool allow_tensor_metadata_change) const;
251
+
252
+ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
253
+ void copy_tensor_metadata_and_refresh(
254
+ const FunctionalTensorWrapper* src_impl,
255
+ FunctionalTensorWrapper* dest_impl,
256
+ const c10::VariableVersion& version_counter,
257
+ bool allow_tensor_metadata_change) const;
258
+
259
+ // Note that value is not taken by reference: internally, the wrapper will
260
+ // change the value tensor that it points to over time.
261
+ Tensor value_;
262
+ int64_t level_{};
263
+ // These two counters are used for identifying
264
+ // whether all the mutations on a given tensor are hidden from autograd or
265
+ // not. If we have an input mutation that is hidden from autograd, then once
266
+ // we convert the input mutation to a copy_() we know it will be safe to hide
267
+ // the copy_() from autograd as well.
268
+ bool has_metadata_mutation_ = false;
269
+ bool is_multi_output_view_ = false;
270
+ // Did the tensor experience a set_() call.
271
+ bool was_storage_changed_ = false;
272
+ // Did the tensor experience any view operation with symbolic int.
273
+ bool is_symbolic_ = false;
274
+
275
+ size_t generation_ = 0;
276
+ std::vector<at::functionalization::ViewMeta> view_metas_;
277
+
278
+ protected:
279
+ static void copy_tensor_metadata(
280
+ const FunctionalTensorWrapper* src_impl,
281
+ FunctionalTensorWrapper* dest_impl,
282
+ const c10::VariableVersion& version_counter,
283
+ bool allow_tensor_metadata_change);
284
+ };
285
+
286
+ // Utility functions for the functionalization pass.
287
+
288
+ namespace functionalization {
289
+ namespace impl {
290
+
291
+ TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
292
+ const Tensor& tensor) {
293
+ auto functional_impl =
294
+ static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
295
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
296
+ return functional_impl;
297
+ }
298
+
299
+ TORCH_API bool isBaseTensor(const at::Tensor& tensor);
300
+
301
+ TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
302
+ TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
303
+ TORCH_API bool isFunctionalTensor(
304
+ const c10::List<std::optional<Tensor>>& t_list);
305
+ TORCH_API bool isFunctionalTensor(ITensorListRef list);
306
+
307
+ TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
308
+ TORCH_API std::optional<Tensor> to_functional_tensor(
309
+ const std::optional<Tensor>& tensor);
310
+ TORCH_API c10::List<std::optional<Tensor>> to_functional_tensor(
311
+ const c10::List<std::optional<Tensor>>& t_list);
312
+ TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
313
+
314
+ TORCH_API void freeze_functional_tensor(const Tensor& tensor);
315
+
316
+ TORCH_API Tensor
317
+ from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
318
+ TORCH_API std::optional<Tensor> from_functional_tensor(
319
+ const std::optional<Tensor>& t,
320
+ bool assert_functional = true);
321
+ TORCH_API c10::List<std::optional<Tensor>> from_functional_tensor(
322
+ const c10::List<std::optional<Tensor>>& t_list);
323
+ TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
324
+
325
+ TORCH_API void sync(const at::Tensor& t);
326
+ TORCH_API void sync(const std::optional<Tensor>& t);
327
+ TORCH_API void sync(const c10::List<std::optional<Tensor>>& t_list);
328
+ TORCH_API void sync(ITensorListRef t_list);
329
+
330
+ TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
331
+ TORCH_API void replace_(
332
+ const ITensorListRef functional_tensor,
333
+ ITensorListRef other);
334
+
335
+ TORCH_API void commit_update(const Tensor& functional_tensor);
336
+ TORCH_API void commit_update(ITensorListRef functional_tensor);
337
+
338
+ TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
339
+
340
+ TORCH_API void mark_mutation_hidden_from_autograd(
341
+ const Tensor& functional_tensor);
342
+
343
+ TORCH_API bool are_all_mutations_hidden_from_autograd(
344
+ const Tensor& functional_tensor);
345
+
346
+ TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
347
+ const Tensor& functional_tensor);
348
+
349
+ // These two methods are XLA-specific logic and are no-ops
350
+ // for the normal functionalization flow.
351
+ TORCH_API void propagate_xla_data(
352
+ const Tensor& functional_tensor,
353
+ const Tensor& other);
354
+ TORCH_API void propagate_xla_data(
355
+ const ITensorListRef functional_tensor,
356
+ ITensorListRef other);
357
+
358
+ TORCH_API void propagate_xla_data_direct(
359
+ const Tensor& tensor,
360
+ const Tensor& other);
361
+ TORCH_API void propagate_xla_data_direct(
362
+ const ITensorListRef tensor,
363
+ ITensorListRef other);
364
+
365
+ Tensor create_functional_tensor_with_view_meta(
366
+ const Tensor& view_to_wrap,
367
+ const Tensor& base,
368
+ functionalization::ViewMeta meta,
369
+ int64_t out_idx = 0);
370
+ std::vector<Tensor> create_functional_tensor_with_view_meta(
371
+ ITensorListRef view_to_wrap,
372
+ const Tensor& base,
373
+ const functionalization::ViewMeta& meta);
374
+
375
+ void mutate_view_meta(
376
+ const Tensor& self,
377
+ const functionalization::ViewMeta& meta);
378
+
379
+ void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
380
+ void set_sizes_strides_offset(
381
+ const std::vector<Tensor>& outs,
382
+ const std::vector<Tensor>& meta_outs);
383
+
384
+ // ~~~~~ TLS used in functionalization ~~~~~
385
+
386
+ TORCH_API bool getFunctionalizationReapplyViewsTLS();
387
+ TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
388
+
389
+ class TORCH_API FunctionalizationReapplyViewsGuard {
390
+ public:
391
+ FunctionalizationReapplyViewsGuard(bool reapply_views)
392
+ : prev_(getFunctionalizationReapplyViewsTLS()) {
393
+ setFunctionalizationReapplyViewsTLS(reapply_views);
394
+ }
395
+
396
+ ~FunctionalizationReapplyViewsGuard() {
397
+ setFunctionalizationReapplyViewsTLS(prev_);
398
+ }
399
+
400
+ FunctionalizationReapplyViewsGuard(
401
+ const FunctionalizationReapplyViewsGuard&) = delete;
402
+ FunctionalizationReapplyViewsGuard operator=(
403
+ const FunctionalizationReapplyViewsGuard&) = delete;
404
+ FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
405
+ delete;
406
+ FunctionalizationReapplyViewsGuard operator=(
407
+ FunctionalizationReapplyViewsGuard&&) = delete;
408
+
409
+ private:
410
+ bool prev_;
411
+ };
412
+
413
+ } // namespace impl
414
+
415
+ // Helper function to call an out-of-place composite aten kernel that may use
416
+ // mutations / views internally, and functionalize them.
417
+ TORCH_API void functionalize_op_helper(
418
+ const c10::OperatorHandle& op,
419
+ torch::jit::Stack* stack);
420
+
421
+ template <class Op, bool symint, class ReturnType, class... ParameterTypes>
422
+ struct _functionalize_aten_op final {};
423
+
424
+ template <class Op, bool symint, class ReturnType, class... ParameterTypes>
425
+ struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
426
+ static ReturnType call(
427
+ typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
428
+ using FuncType = ReturnType(
429
+ typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
430
+ auto op = c10::Dispatcher::singleton()
431
+ .findSchemaOrThrow(
432
+ (const char*)Op::name, (const char*)Op::overload_name)
433
+ .typed<FuncType>();
434
+
435
+ return c10::impl::BoxedKernelWrapper<FuncType>::call(
436
+ c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
437
+ op,
438
+ // BoxedKernelWrapper knows to ignore this keyset argument,
439
+ // because functionalize_op_helper doesn't take in a DispatchKeySet
440
+ c10::DispatchKeySet(),
441
+ args...);
442
+ }
443
+ };
444
+
445
+ template <class Op>
446
+ using functionalize_aten_op =
447
+ _functionalize_aten_op<Op, false, typename Op::schema>;
448
+
449
+ template <class Op>
450
+ using functionalize_aten_op_symint =
451
+ _functionalize_aten_op<Op, true, typename Op::schema>;
452
+
453
+ } // namespace functionalization
454
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/InferSize.h ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/DimVector.h>
4
+ #include <c10/core/ScalarType.h>
5
+ #include <c10/core/SymIntArrayRef.h>
6
+ #include <c10/util/DimVector.h>
7
+ #include <optional>
8
+ #include <sstream>
9
+ #include <vector>
10
+
11
+ namespace at {
12
+
13
+ // Infers the size of a dim with size -1, if it exists. Also checks that new
14
+ // shape is compatible with the number of elements.
15
+ //
16
+ // templated to handle std::vector<int64_t> and DimVector use cases, see
17
+ // below
18
+ //
19
+ template <typename InputArrayRef, typename NumelType, typename ResultVec>
20
+ inline void infer_size_impl(
21
+ InputArrayRef shape,
22
+ NumelType numel,
23
+ ResultVec& res) {
24
+ NumelType newsize = 1;
25
+ // N.B. this is an index, not a sym dim!
26
+ std::optional<int64_t> infer_dim;
27
+ for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
28
+ if (shape[dim] == -1) {
29
+ if (infer_dim) {
30
+ throw std::runtime_error("only one dimension can be inferred");
31
+ }
32
+ infer_dim = dim;
33
+ } else if (shape[dim] >= 0) {
34
+ newsize *= shape[dim];
35
+ } else {
36
+ AT_ERROR("invalid shape dimension ", shape[dim]);
37
+ }
38
+ }
39
+
40
+ if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, newsize)) ||
41
+ (infer_dim && newsize > 0 && numel % newsize == 0)) {
42
+ if (infer_dim) {
43
+ // We have a degree of freedom here to select the dimension size; follow
44
+ // NumPy semantics and just bail. However, a nice error message is needed
45
+ // because users often use `view` as a way to flatten & unflatten
46
+ // dimensions and will otherwise be confused why
47
+ // empty_tensor.view( 0, 0)
48
+ // works yet
49
+ // empty_tensor.view(-1, 0)
50
+ // doesn't.
51
+ TORCH_CHECK(
52
+ newsize != 0,
53
+ "cannot reshape tensor of 0 elements into shape ",
54
+ shape,
55
+ " because the unspecified dimension size -1 can be any "
56
+ "value and is ambiguous");
57
+ res[*infer_dim] = numel / newsize;
58
+ }
59
+ return;
60
+ }
61
+
62
+ std::ostringstream ss;
63
+ ss << "shape '" << shape << "' is invalid for input of size " << numel;
64
+ throw std::runtime_error(ss.str());
65
+ }
66
+
67
+ inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
68
+ auto res = shape.vec();
69
+ infer_size_impl(shape, numel, res);
70
+ return res;
71
+ }
72
+
73
+ inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
74
+ auto res = at::DimVector(shape);
75
+ infer_size_impl(shape, numel, res);
76
+ return res;
77
+ }
78
+
79
+ inline at::SymDimVector infer_size_dv(
80
+ c10::SymIntArrayRef shape,
81
+ c10::SymInt numel) {
82
+ auto res = at::SymDimVector(shape);
83
+ infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
84
+ shape, std::move(numel), res);
85
+ return res;
86
+ }
87
+
88
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/InitialTensorOptions.h ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/TensorOptions.h>
4
+
5
+ namespace at {
6
+
7
+ // Represents the initial TensorOptions, before the "defaults" are ever changed.
8
+ // This is designed to be used in library code, where the explicit devices,
9
+ // dtypes, etc. are known. NOTE: this is not a stable API.
10
+ inline TensorOptions initialTensorOptions() {
11
+ return TensorOptions(kCPU).dtype(kFloat).layout(kStrided).requires_grad(
12
+ false);
13
+ }
14
+
15
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/Layout.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #pragma once
2
+ #include <c10/core/Layout.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedFallback.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/ATen.h>
3
+ #include <ATen/core/op_registration/op_registration.h>
4
+ #include <torch/library.h>
5
+
6
+ namespace at {
7
+
8
+ // If an operator doesn't have a batching rule implemented then we fallback
9
+ // to this implementation. The fallback only works on out-of-place operators
10
+ // that return only tensors with new memory. (e.g., no in-place operators, no
11
+ // view operations).
12
+ //
13
+ // The fallback effectively takes all of the BatchedTensors in `stack`, slices
14
+ // them, and runs `op` on all of the corresponding slices to produce slices
15
+ // of the outputs. The output slices then get `torch.stack`ed to create the
16
+ // final returns.
17
+ //
18
+ // The performance of the fallback is not very good because it introduces an
19
+ // extra copy from stacking the sliced outputs. Because of this, we prefer to
20
+ // write batching rules for operators whenever possible.
21
+ void batchedTensorForLoopFallback(
22
+ const c10::OperatorHandle& op,
23
+ torch::jit::Stack* stack);
24
+
25
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <bitset>
4
+
5
+ #include <ATen/ArrayRef.h>
6
+ #include <ATen/SmallVector.h>
7
+ #include <ATen/Tensor.h>
8
+
9
+ namespace at {
10
+
11
+ // We assume this in a few other places in the codebase,
12
+ // but there isn't a centralized definition.
13
+ constexpr int64_t kVmapMaxTensorDims = 64;
14
+
15
+ // The valid vmap levels range from [0, 64). This effectively means that we
16
+ // support a maximum of 64 nested vmaps.
17
+ constexpr int64_t kVmapNumLevels = 64;
18
+
19
+ // Store this number of elements of BatchDims on the stack. Most people will
20
+ // probably use <= 5 nested vmaps, but adjust this number as necessary.
21
+ constexpr int64_t kBatchDimsStackSize = 5;
22
+
23
+ // a BatchDim represents a "private" dimension on a Tensor created inside of
24
+ // vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
25
+ // is being vmap'ed over and the `level` being an identifier for which vmap
26
+ // said dimension was created inside. The `dim` corresponds to a "physical
27
+ // dim" - it is a dimension index on the underlying physical tensor that is
28
+ // being vmapped over.
29
+ struct BatchDim {
30
+ BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
31
+ int64_t dim() const {
32
+ return dim_;
33
+ }
34
+ int64_t level() const {
35
+ return level_;
36
+ }
37
+
38
+ private:
39
+ int64_t dim_;
40
+ int64_t level_;
41
+ };
42
+
43
+ using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
44
+ using BatchDimsRef = ArrayRef<BatchDim>;
45
+
46
+ // A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
47
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
48
+ // BatchedTensorImpl.
49
+ //
50
+ // The batch dimensions are treated as being "private"; they are not
51
+ // user-visible. For example, in the following Tensor,
52
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
53
+ // dimensions 0 and 1 are batch dimensions.
54
+ //
55
+ // bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
56
+ // dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7)
57
+ // tensor.
58
+ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
59
+ explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
60
+
61
+ // Returns a reference to BatchDims that represent which dimensions of this
62
+ // tensor are private.
63
+ BatchDimsRef bdims() const {
64
+ return bdims_;
65
+ }
66
+
67
+ // BatchedTensorImpl wraps a Tensor
68
+ const Tensor& value() const {
69
+ return value_;
70
+ };
71
+
72
+ // Given a public dimension index, return the dimension index in the
73
+ // underlying value() tensor. For example, if we have
74
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2,
75
+ // dim=2)])
76
+ // bt.actualDim(0) -> 1
77
+ // bt.actualDim(1) -> 3
78
+ // bt.actualDim(2) -> Error
79
+ int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
80
+
81
+ // We have to override this because we opted into CustomStrides
82
+ IntArrayRef strides_custom() const override;
83
+ // Override a bunch of methods inherited from TensorImpl to return error
84
+ // messages.
85
+ bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
86
+ void set_size(int64_t dim, int64_t new_size) override;
87
+ void set_stride(int64_t dim, int64_t new_stride) override;
88
+ void set_storage_offset(int64_t storage_offset) override;
89
+ #ifdef DEBUG
90
+ bool has_storage() const override;
91
+ #endif
92
+
93
+ private:
94
+ // see NOTE: [BatchedTensorImpl levels invariant]
95
+ void checkInvariants() const;
96
+ const char* tensorimpl_type_name() const override;
97
+
98
+ Tensor value_;
99
+
100
+ // Note: [BatchedTensorImpl levels invariant]
101
+ // There is an invariant that the BatchDims must be stored in increasing
102
+ // `level` order. That is, for i < j, bdims_[i].level must be less than
103
+ // bdims_[j].level.
104
+ BatchDims bdims_;
105
+ };
106
+
107
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
108
+ // BatchedTensorImpl.
109
+ inline bool isBatchedTensor(const Tensor& tensor) {
110
+ return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
111
+ }
112
+
113
+ // It is unsafe to call this on a Tensor that is not backed by a
114
+ // BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
115
+ inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
116
+ return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
117
+ }
118
+
119
+ inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
120
+ if (!isBatchedTensor(tensor)) {
121
+ return nullptr;
122
+ }
123
+ return unsafeGetBatchedImpl(tensor);
124
+ }
125
+
126
+ // Returns a bitset. If bit i is set, then that means dim i is a batchdim.
127
+ inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
128
+ BatchDimsRef bdims) {
129
+ std::bitset<kVmapMaxTensorDims> is_bdim;
130
+ for (const auto& bdim : bdims) {
131
+ is_bdim.set(bdim.dim());
132
+ }
133
+ return is_bdim;
134
+ }
135
+
136
+ // Creates a bitset for all of the levels present in `bdims`
137
+ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
138
+ std::bitset<kVmapNumLevels> result;
139
+ for (const auto& bdim : bdims) {
140
+ result.set(bdim.level());
141
+ }
142
+ return result;
143
+ }
144
+
145
+ inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
146
+ out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
147
+ return out;
148
+ }
149
+
150
+ // Use this to construct a BatchedTensor from a regular Tensor
151
+ TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
152
+
153
+ // Adds a batch dim to `tensor`, returning a BatchedTensor
154
+ TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
155
+
156
+ // Checks if an inplace operation on self and other is "vmap compatible".
157
+ // See NOTE: [vmap-incompatible in-place operations] for the definition of this.
158
+ TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
159
+
160
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyVmapMode.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/impl/LocalDispatchKeySet.h>
4
+
5
+ namespace at::impl {
6
+
7
+ // VmapMode contains a thread local count of how many nested vmaps
8
+ // we are currently inside. That number is known as the `vmap level`.
9
+ // VmapMode is used in the implementation of the Python `torch.vmap` API.
10
+ //
11
+ // NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
12
+
13
+ struct TORCH_API VmapMode {
14
+ // Returns the vmap level, aka the count of how many nested vmaps we're in.
15
+ static int64_t current_vmap_level();
16
+
17
+ // Increment the count of nested vmaps. If this causes the vmap level to be
18
+ // greater than 0, then it enables DispatchKey::VmapMode on all tensors.
19
+ static int64_t increment_nesting();
20
+
21
+ // Decrements the count of nested vmaps. If this causes the vmap level to be
22
+ // equal to 0, then it disables DispatchKey::VmapMode on all tensors.
23
+ static int64_t decrement_nesting();
24
+ };
25
+
26
+ } // namespace at::impl
.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyVmapTransforms.h ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/LegacyBatchedTensorImpl.h>
4
+ #include <ATen/core/IListRef.h>
5
+
6
+ namespace at {
7
+
8
+ // This file contains abstractions used for transforming *logical* vmap
9
+ // arguments into *physical* arguments. (Keep reading for definitions of these
10
+ // terms).
11
+
12
+ // NOTE: [Logical vs physical args]
13
+ // Consider the following vmap.
14
+ // vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
15
+ // This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
16
+ // with batch dims 0 and 2:
17
+ // BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
18
+ //
19
+ // We say the *logical* view of the tensor has size [3] -- tensors inside
20
+ // `func` appear to have size [3].
21
+ // However, the *physical* underlying tensor (the one passed to vmap) has size
22
+ // [2, 3, 4].
23
+ //
24
+ // This notion of logical vs physical also extends to non-tensor arguments.
25
+ // Consider the previous tensor; let's assume the user called
26
+ // `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
27
+ // dimension they are reducing over is dim 0 but the physical dim is dim 1
28
+ // (the first non-batch dimension)
29
+
30
+ // Forward declared; see NOTE: [What is a VmapPhysicalView?]
31
+ struct VmapPhysicalView;
32
+
33
+ // Most PyTorch operators take 4 or fewer inputs.
34
+ constexpr int64_t kVmapTransformStaticInputSize = 4;
35
+ using VmapPhysicalViewVec =
36
+ SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
37
+
38
+ // Pytorch generally advertises good performance for <= 5 dims.
39
+ // (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
40
+ // dimensions to get 8. Adjust this number as necessary
41
+ constexpr int64_t kVmapStaticDimVecSize = 8;
42
+ using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
43
+ using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;
44
+
45
+ // NOTE: [What is an VmapTransform?]
46
+ // An *VmapTransform* converts logical views of tensors to physical views.
47
+ //
48
+ // Batching rules use VmapTransforms to convert logical arguments to
49
+ // physical arguments, then call one or more at:: operator that handles the
50
+ // physical arguments, and then converts the physical result back to a logical
51
+ // argument.
52
+
53
+ // VmapTransform for operators that take tensors with multiple batch dims.
54
+ // Given one or more logical views on Tensors, `logicalToPhysical`
55
+ // permutes all of the batch dims to the front of the tensor, aligns
56
+ // and expands the batch dims to match each other (according to their `level`),
57
+ // and returns a VmapPhysicalView on the tensor(s).
58
+ struct TORCH_API MultiBatchVmapTransform {
59
+ static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
60
+ static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
61
+ };
62
+
63
+ // VmapTransform for operators that broadcast all inputs.
64
+ // Given some logical views on Tensors, `logicalToPhysical`:
65
+ // - permutes all of the batch dims to the front of the tensors
66
+ // - aligns all the batch dims to the collective levels of all of the tensors.
67
+ // If a tensor does not have a batch dim for a vmap level, then it receives
68
+ // a size-one dimension for said level.
69
+ // - aligns the non-batch dims to have the same dimensionality, adding extra
70
+ // size-1 dimensions in between the batch dimensions and the non-batch
71
+ // dimensions so that the batch dimensions are lined up from the right.
72
+ //
73
+ // For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
74
+ // dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
75
+ // tensors of size (B, 1, 2) and (B, 3, 2).
76
+ //
77
+ // Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
78
+ // VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
79
+ // actually *need* to return a tensor of size (1, 2) for the second tensor
80
+ // because the broadcasting operation takes care of that for us, but we do
81
+ // it anyways to keep things simple.
82
+ struct TORCH_API BroadcastingVmapTransform {
83
+ static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
84
+ };
85
+
86
+ // Forward declared, if you're reading this file head to toe, don't worry about
87
+ // it yet.
88
+ struct VmapPhysicalToLogicalMap;
89
+
90
+ // NOTE: [What is a VmapPhysicalView?]
91
+ // VmapPhysicalView represents a physical view on a Tensor.
92
+ //
93
+ // One can use it to further convert logical dimension indices, logical shapes,
94
+ // and more to their physical variants, or convert a new (physical) tensor into
95
+ // a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
96
+ //
97
+ // VmapPhysicalView stores a physical tensor with all of its batch dimensions at
98
+ // the front and some levels that correspond to said batch dimensions.
99
+ //
100
+ // The levels bitset specifies which vmap levels correspond to the batch
101
+ // dimensions at the front of the tensor. In particular, the number of set bits
102
+ // corresponds to the number of batch dimensions on `tensor` and the rightmost
103
+ // bit of `levels` specifies the maximum number of nested vmaps we are in at
104
+ // this point in time.
105
+ // For example, given:
106
+ // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
107
+ //
108
+ // Rightmost bit of `levels` is 3 indicating the number of nested vmaps less
109
+ // than or equal to 3.
110
+ // bitset: 010100
111
+ // ^
112
+ // |
113
+ // levels: 012345
114
+ struct TORCH_API VmapPhysicalView {
115
+ VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
116
+ : levels_(levels), tensor_(std::move(tensor)) {
117
+ TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
118
+ }
119
+
120
+ Tensor& tensor() {
121
+ return tensor_;
122
+ }
123
+ const Tensor& tensor() const {
124
+ return tensor_;
125
+ }
126
+
127
+ // Maps logical dim indices to physical dim indices. Also does dim wrapping.
128
+ //
129
+ // For example, given:
130
+ // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
131
+ //
132
+ // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
133
+ // This is because the size of levels tell us that the first two dimensions
134
+ // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
135
+ // a physical dim of `n + 2`.
136
+ VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
137
+ int64_t getPhysicalDim(int64_t logical_dim) const;
138
+
139
+ // Returns a VmapPhysicalToLogicalMap object. This can be used for
140
+ // mapping a physical tensor to a new logical tensor (BatchedTensor)
141
+ VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
142
+
143
+ // Maps a logical shape to a physical shape by pre-pending the batch
144
+ // sizes to the logical shape.
145
+ VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
146
+
147
+ int64_t numBatchDims() const;
148
+
149
+ private:
150
+ int64_t numLogicalDims() const;
151
+
152
+ std::bitset<kVmapNumLevels> levels_;
153
+ Tensor tensor_;
154
+ };
155
+
156
+ // Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
157
+ // to a logical one (BatchedTensor). It holds some levels that are used to do
158
+ // the mapping and assumes that the batch dimensions in the physical tensor all
159
+ // occur at the front of the tensor.
160
+ struct TORCH_API VmapPhysicalToLogicalMap {
161
+ VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
162
+ : levels_(levels) {}
163
+
164
+ // Maps a physical tensor to a new logical tensor (BatchedTensor).
165
+ // Assumes that all of the "batch dimensions" are at the front
166
+ // of the physical tensor. For example, given:
167
+ // - x = rank-4 Tensor with size 2, 3, 5, 7
168
+ // - levels = (2, 4)
169
+ // Returns:
170
+ // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
171
+ Tensor apply(const Tensor& physical_tensor) const;
172
+
173
+ // Given a vector of physical tensors,
174
+ // 1. maps each tensor to a new logical tensor. Assumes that all of the
175
+ // "batch dimensions" are at the front of the physical tensors.
176
+ // 2. stores the new logical tensors back into the passed-in vector. This is
177
+ // to avoid additional dynamic allocations.
178
+ void applyInplace(std::vector<Tensor>& physical_tensors) const;
179
+
180
+ std::bitset<kVmapNumLevels> levels_;
181
+ };
182
+
183
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/MapAllocator.h ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/Allocator.h>
4
+ #include <c10/util/string_view.h>
5
+
6
+ namespace at {
7
+
8
+ enum MappedAllocatorModes {
9
+ ALLOCATOR_MAPPED_SHARED = 1,
10
+ ALLOCATOR_MAPPED_SHAREDMEM = 2,
11
+ ALLOCATOR_MAPPED_EXCLUSIVE = 4,
12
+ ALLOCATOR_MAPPED_NOCREATE = 8,
13
+ ALLOCATOR_MAPPED_KEEPFD = 16,
14
+ ALLOCATOR_MAPPED_FROMFD = 32,
15
+ ALLOCATOR_MAPPED_UNLINK = 64
16
+ };
17
+
18
+ // Sentinel value/type to help distinguish the file descriptor constructor from
19
+ // the non-file descriptor constructor
20
+ enum WithFd { WITH_FD };
21
+
22
+ TORCH_API std::string NewProcessWideShmHandle();
23
+
24
+ class TORCH_API MapAllocator {
25
+ public:
26
+ MapAllocator(c10::string_view filename, int flags, size_t size);
27
+ MapAllocator(
28
+ WithFd,
29
+ c10::string_view filename,
30
+ int fd,
31
+ int flags,
32
+ size_t size);
33
+ MapAllocator(const MapAllocator&) = delete;
34
+ MapAllocator& operator=(const MapAllocator&) = delete;
35
+ MapAllocator(MapAllocator&&) = delete;
36
+ MapAllocator& operator=(MapAllocator&&) = delete;
37
+
38
+ const char* filename() const {
39
+ return filename_.c_str();
40
+ }
41
+ int fd() const {
42
+ #ifdef _WIN32
43
+ TORCH_CHECK(false, "MapAllocator::fd() is unsupported on Windows");
44
+ #else
45
+ return fd_;
46
+ #endif
47
+ }
48
+ ptrdiff_t size() const {
49
+ return size_;
50
+ }
51
+ // Return a pointer to the actual data for this allocator
52
+ // (in the case of the refcounted allocator, this is offset
53
+ // from the base pointer.)
54
+ virtual void* data() const {
55
+ return base_ptr_;
56
+ }
57
+
58
+ int flags() const {
59
+ return flags_;
60
+ }
61
+
62
+ static MapAllocator* fromDataPtr(const at::DataPtr&);
63
+ static at::DataPtr makeDataPtr(
64
+ c10::string_view filename,
65
+ int flags,
66
+ size_t size,
67
+ size_t* actual_size_out);
68
+ static at::DataPtr makeDataPtr(
69
+ WithFd,
70
+ const char* filename,
71
+ int fd,
72
+ int flags,
73
+ size_t size,
74
+ size_t* actual_size_out);
75
+
76
+ // Closes the data. Helps us avoid destructor shenanigans
77
+ virtual void close();
78
+
79
+ // This is very dangerous. You have to redefine this destructor for each
80
+ // subclass
81
+ virtual ~MapAllocator();
82
+
83
+ protected:
84
+ bool closed_ = false;
85
+ std::string filename_;
86
+ int flags_ = 0;
87
+ ptrdiff_t size_; /* mapped size */
88
+ #ifdef _WIN32
89
+ void* handle_;
90
+ void* event_;
91
+ std::string eventname_;
92
+ #else
93
+ int fd_ = -1;
94
+ #endif
95
+ void* base_ptr_ = nullptr;
96
+ };
97
+
98
+ // Base-from-member idiom
99
+ struct TORCH_API RefcountedMapAllocatorArgCheck {
100
+ RefcountedMapAllocatorArgCheck(int flags);
101
+ };
102
+
103
+ class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck,
104
+ public MapAllocator {
105
+ public:
106
+ RefcountedMapAllocator(const char* filename, int flags, size_t size);
107
+ RefcountedMapAllocator(
108
+ WithFd,
109
+ const char* filename,
110
+ int fd,
111
+ int flags,
112
+ size_t size);
113
+
114
+ static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&);
115
+ static at::DataPtr makeDataPtr(
116
+ const char* filename,
117
+ int flags,
118
+ size_t size,
119
+ size_t* actual_size_out);
120
+ static at::DataPtr makeDataPtr(
121
+ WithFd,
122
+ const char* filename,
123
+ int fd,
124
+ int flags,
125
+ size_t size,
126
+ size_t* actual_size_out);
127
+
128
+ void* data() const override;
129
+
130
+ void incref();
131
+ int decref();
132
+ void close() override;
133
+
134
+ ~RefcountedMapAllocator() override {
135
+ RefcountedMapAllocator::close();
136
+ }
137
+
138
+ protected:
139
+ void checkFlags();
140
+ void initializeAlloc();
141
+ };
142
+
143
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/MatrixRef.h ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/Utils.h>
3
+ #include <c10/util/ArrayRef.h>
4
+
5
+ namespace at {
6
+ /// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that
7
+ /// we can easily view it as a multidimensional array.
8
+ ///
9
+ /// Like ArrayRef, this class does not own the underlying data, it is expected
10
+ /// to be used in situations where the data resides in some other buffer.
11
+ ///
12
+ /// This is intended to be trivially copyable, so it should be passed by
13
+ /// value.
14
+ ///
15
+ /// For now, 2D only (so the copies are actually cheap, without having
16
+ /// to write a SmallVector class) and contiguous only (so we can
17
+ /// return non-strided ArrayRef on index).
18
+ ///
19
+ /// P.S. dimension 0 indexes rows, dimension 1 indexes columns
20
+ template <typename T>
21
+ class MatrixRef {
22
+ public:
23
+ typedef size_t size_type;
24
+
25
+ private:
26
+ /// Underlying ArrayRef
27
+ ArrayRef<T> arr;
28
+
29
+ /// Stride of dim 0 (outer dimension)
30
+ size_type stride0;
31
+
32
+ // Stride of dim 1 is assumed to be 1
33
+
34
+ public:
35
+ /// Construct an empty Matrixref.
36
+ /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
37
+
38
+ /// Construct an MatrixRef from an ArrayRef and outer stride.
39
+ /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
40
+ : arr(arr), stride0(stride0) {
41
+ TORCH_CHECK(
42
+ arr.size() % stride0 == 0,
43
+ "MatrixRef: ArrayRef size ",
44
+ arr.size(),
45
+ " not divisible by stride ",
46
+ stride0)
47
+ }
48
+
49
+ /// @}
50
+ /// @name Simple Operations
51
+ /// @{
52
+
53
+ /// empty - Check if the matrix is empty.
54
+ bool empty() const {
55
+ return arr.empty();
56
+ }
57
+
58
+ const T* data() const {
59
+ return arr.data();
60
+ }
61
+
62
+ /// size - Get size a dimension
63
+ size_t size(size_t dim) const {
64
+ if (dim == 0) {
65
+ return arr.size() / stride0;
66
+ } else if (dim == 1) {
67
+ return stride0;
68
+ } else {
69
+ TORCH_CHECK(
70
+ 0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
71
+ }
72
+ }
73
+
74
+ size_t numel() const {
75
+ return arr.size();
76
+ }
77
+
78
+ /// equals - Check for element-wise equality.
79
+ bool equals(MatrixRef RHS) const {
80
+ return stride0 == RHS.stride0 && arr.equals(RHS.arr);
81
+ }
82
+
83
+ /// @}
84
+ /// @name Operator Overloads
85
+ /// @{
86
+ ArrayRef<T> operator[](size_t Index) const {
87
+ return arr.slice(Index * stride0, stride0);
88
+ }
89
+
90
+ /// Disallow accidental assignment from a temporary.
91
+ ///
92
+ /// The declaration here is extra complicated so that "arrayRef = {}"
93
+ /// continues to select the move assignment operator.
94
+ template <typename U>
95
+ std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
96
+ U&& Temporary) = delete;
97
+
98
+ /// Disallow accidental assignment from a temporary.
99
+ ///
100
+ /// The declaration here is extra complicated so that "arrayRef = {}"
101
+ /// continues to select the move assignment operator.
102
+ template <typename U>
103
+ std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
104
+ std::initializer_list<U>) = delete;
105
+ };
106
+
107
+ } // end namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/TensorBody.h>
2
+
3
+ // TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
4
+ // Code introduced to avoid cyclic dependency in static dispatch is no longer
5
+ // needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
6
+ // to Operators.cpp for supporting multiple backends with multiple kernels.
7
+ //
8
+ // Note [Avoiding Include Cycles In Static Dispatch]
9
+ // In order to avoid #include cycles in the static dispatch build, we've carefully split out
10
+ // the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
11
+ //
12
+ // Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
13
+ // - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
14
+ // all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
15
+ // directly inlined into TensorBody.h.
16
+ // - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
17
+ // which include functions that have defaultable std::optional<Tensor> arguments.
18
+ // That requires knowing the full Tensor class definition.
19
+ //
20
+ // We break the cycle by doing the following:
21
+ // - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
22
+ // - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
23
+ // - CPUFunctions_inl.h includes everything else
24
+ // - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
25
+ // and then it includes CPUFunctions_inl.h.
26
+ // - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
27
+ // - This also means that static dispatch build, CPUFunctions.h only needs to
28
+ // #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
29
+ #include <ATen/MetaFunctions_inl.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/MethodOperators.h ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from MethodOperators.h
4
+
5
+ #ifdef TORCH_ASSERT_NO_OPERATORS
6
+ #error This change adds a dependency on native_functions.yaml, \
7
+ meaning the file will need to be re-compiled every time an operator \
8
+ is changed or added. Consider if your change would be better placed in \
9
+ another file, or if a more specific header might achieve the same goal. \
10
+ See NOTE: [Tensor vs. TensorBase]
11
+ #endif
12
+
13
+ // Forward declarations of any types needed in the operator signatures.
14
+ // We can't directly include these classes because it will cause circular include dependencies.
15
+ // This file is included by TensorBody.h, which defines the Tensor class.
16
+ #include <ATen/core/ATen_fwd.h>
17
+
18
+ #include <ATen/ops/_addmm_activation_ops.h>
19
+ #include <ATen/ops/_autocast_to_full_precision_ops.h>
20
+ #include <ATen/ops/_autocast_to_reduced_precision_ops.h>
21
+ #include <ATen/ops/_backward_ops.h>
22
+ #include <ATen/ops/_coalesced_ops.h>
23
+ #include <ATen/ops/_conj_ops.h>
24
+ #include <ATen/ops/_conj_physical_ops.h>
25
+ #include <ATen/ops/_dimI_ops.h>
26
+ #include <ATen/ops/_dimV_ops.h>
27
+ #include <ATen/ops/_fw_primal_ops.h>
28
+ #include <ATen/ops/_indices_ops.h>
29
+ #include <ATen/ops/_is_all_true_ops.h>
30
+ #include <ATen/ops/_is_any_true_ops.h>
31
+ #include <ATen/ops/_is_zerotensor_ops.h>
32
+ #include <ATen/ops/_lazy_clone_ops.h>
33
+ #include <ATen/ops/_neg_view_ops.h>
34
+ #include <ATen/ops/_nested_tensor_size_ops.h>
35
+ #include <ATen/ops/_nested_tensor_storage_offsets_ops.h>
36
+ #include <ATen/ops/_nested_tensor_strides_ops.h>
37
+ #include <ATen/ops/_nnz_ops.h>
38
+ #include <ATen/ops/_reshape_alias_ops.h>
39
+ #include <ATen/ops/_sparse_mask_projection_ops.h>
40
+ #include <ATen/ops/_to_dense_ops.h>
41
+ #include <ATen/ops/_to_sparse_bsc_ops.h>
42
+ #include <ATen/ops/_to_sparse_bsr_ops.h>
43
+ #include <ATen/ops/_to_sparse_csc_ops.h>
44
+ #include <ATen/ops/_to_sparse_csr_ops.h>
45
+ #include <ATen/ops/_to_sparse_ops.h>
46
+ #include <ATen/ops/_values_ops.h>
47
+ #include <ATen/ops/_version_ops.h>
48
+ #include <ATen/ops/abs_ops.h>
49
+ #include <ATen/ops/absolute_ops.h>
50
+ #include <ATen/ops/acos_ops.h>
51
+ #include <ATen/ops/acosh_ops.h>
52
+ #include <ATen/ops/add_ops.h>
53
+ #include <ATen/ops/addbmm_ops.h>
54
+ #include <ATen/ops/addcdiv_ops.h>
55
+ #include <ATen/ops/addcmul_ops.h>
56
+ #include <ATen/ops/addmm_ops.h>
57
+ #include <ATen/ops/addmv_ops.h>
58
+ #include <ATen/ops/addr_ops.h>
59
+ #include <ATen/ops/adjoint_ops.h>
60
+ #include <ATen/ops/alias_ops.h>
61
+ #include <ATen/ops/align_as_ops.h>
62
+ #include <ATen/ops/align_to_ops.h>
63
+ #include <ATen/ops/all_ops.h>
64
+ #include <ATen/ops/allclose_ops.h>
65
+ #include <ATen/ops/amax_ops.h>
66
+ #include <ATen/ops/amin_ops.h>
67
+ #include <ATen/ops/aminmax_ops.h>
68
+ #include <ATen/ops/and_ops.h>
69
+ #include <ATen/ops/angle_ops.h>
70
+ #include <ATen/ops/any_ops.h>
71
+ #include <ATen/ops/arccos_ops.h>
72
+ #include <ATen/ops/arccosh_ops.h>
73
+ #include <ATen/ops/arcsin_ops.h>
74
+ #include <ATen/ops/arcsinh_ops.h>
75
+ #include <ATen/ops/arctan2_ops.h>
76
+ #include <ATen/ops/arctan_ops.h>
77
+ #include <ATen/ops/arctanh_ops.h>
78
+ #include <ATen/ops/argmax_ops.h>
79
+ #include <ATen/ops/argmin_ops.h>
80
+ #include <ATen/ops/argsort_ops.h>
81
+ #include <ATen/ops/argwhere_ops.h>
82
+ #include <ATen/ops/as_strided_ops.h>
83
+ #include <ATen/ops/as_strided_scatter_ops.h>
84
+ #include <ATen/ops/asin_ops.h>
85
+ #include <ATen/ops/asinh_ops.h>
86
+ #include <ATen/ops/atan2_ops.h>
87
+ #include <ATen/ops/atan_ops.h>
88
+ #include <ATen/ops/atanh_ops.h>
89
+ #include <ATen/ops/baddbmm_ops.h>
90
+ #include <ATen/ops/bernoulli_ops.h>
91
+ #include <ATen/ops/bincount_ops.h>
92
+ #include <ATen/ops/bitwise_and_ops.h>
93
+ #include <ATen/ops/bitwise_left_shift_ops.h>
94
+ #include <ATen/ops/bitwise_not_ops.h>
95
+ #include <ATen/ops/bitwise_or_ops.h>
96
+ #include <ATen/ops/bitwise_right_shift_ops.h>
97
+ #include <ATen/ops/bitwise_xor_ops.h>
98
+ #include <ATen/ops/bmm_ops.h>
99
+ #include <ATen/ops/broadcast_to_ops.h>
100
+ #include <ATen/ops/cauchy_ops.h>
101
+ #include <ATen/ops/ccol_indices_ops.h>
102
+ #include <ATen/ops/ceil_ops.h>
103
+ #include <ATen/ops/chalf_ops.h>
104
+ #include <ATen/ops/cholesky_inverse_ops.h>
105
+ #include <ATen/ops/cholesky_ops.h>
106
+ #include <ATen/ops/cholesky_solve_ops.h>
107
+ #include <ATen/ops/chunk_ops.h>
108
+ #include <ATen/ops/clamp_max_ops.h>
109
+ #include <ATen/ops/clamp_min_ops.h>
110
+ #include <ATen/ops/clamp_ops.h>
111
+ #include <ATen/ops/clip_ops.h>
112
+ #include <ATen/ops/clone_ops.h>
113
+ #include <ATen/ops/coalesce_ops.h>
114
+ #include <ATen/ops/col_indices_ops.h>
115
+ #include <ATen/ops/conj_ops.h>
116
+ #include <ATen/ops/conj_physical_ops.h>
117
+ #include <ATen/ops/contiguous_ops.h>
118
+ #include <ATen/ops/copy_ops.h>
119
+ #include <ATen/ops/copysign_ops.h>
120
+ #include <ATen/ops/corrcoef_ops.h>
121
+ #include <ATen/ops/cos_ops.h>
122
+ #include <ATen/ops/cosh_ops.h>
123
+ #include <ATen/ops/count_nonzero_ops.h>
124
+ #include <ATen/ops/cov_ops.h>
125
+ #include <ATen/ops/cross_ops.h>
126
+ #include <ATen/ops/crow_indices_ops.h>
127
+ #include <ATen/ops/cummax_ops.h>
128
+ #include <ATen/ops/cummin_ops.h>
129
+ #include <ATen/ops/cumprod_ops.h>
130
+ #include <ATen/ops/cumsum_ops.h>
131
+ #include <ATen/ops/data_ops.h>
132
+ #include <ATen/ops/deg2rad_ops.h>
133
+ #include <ATen/ops/dense_dim_ops.h>
134
+ #include <ATen/ops/dequantize_ops.h>
135
+ #include <ATen/ops/det_ops.h>
136
+ #include <ATen/ops/detach_ops.h>
137
+ #include <ATen/ops/diag_embed_ops.h>
138
+ #include <ATen/ops/diag_ops.h>
139
+ #include <ATen/ops/diagflat_ops.h>
140
+ #include <ATen/ops/diagonal_ops.h>
141
+ #include <ATen/ops/diagonal_scatter_ops.h>
142
+ #include <ATen/ops/diff_ops.h>
143
+ #include <ATen/ops/digamma_ops.h>
144
+ #include <ATen/ops/dist_ops.h>
145
+ #include <ATen/ops/div_ops.h>
146
+ #include <ATen/ops/divide_ops.h>
147
+ #include <ATen/ops/dot_ops.h>
148
+ #include <ATen/ops/dsplit_ops.h>
149
+ #include <ATen/ops/eq_ops.h>
150
+ #include <ATen/ops/equal_ops.h>
151
+ #include <ATen/ops/erf_ops.h>
152
+ #include <ATen/ops/erfc_ops.h>
153
+ #include <ATen/ops/erfinv_ops.h>
154
+ #include <ATen/ops/exp2_ops.h>
155
+ #include <ATen/ops/exp_ops.h>
156
+ #include <ATen/ops/expand_as_ops.h>
157
+ #include <ATen/ops/expand_ops.h>
158
+ #include <ATen/ops/expm1_ops.h>
159
+ #include <ATen/ops/exponential_ops.h>
160
+ #include <ATen/ops/fill_diagonal_ops.h>
161
+ #include <ATen/ops/fill_ops.h>
162
+ #include <ATen/ops/fix_ops.h>
163
+ #include <ATen/ops/flatten_ops.h>
164
+ #include <ATen/ops/flip_ops.h>
165
+ #include <ATen/ops/fliplr_ops.h>
166
+ #include <ATen/ops/flipud_ops.h>
167
+ #include <ATen/ops/float_power_ops.h>
168
+ #include <ATen/ops/floor_divide_ops.h>
169
+ #include <ATen/ops/floor_ops.h>
170
+ #include <ATen/ops/fmax_ops.h>
171
+ #include <ATen/ops/fmin_ops.h>
172
+ #include <ATen/ops/fmod_ops.h>
173
+ #include <ATen/ops/frac_ops.h>
174
+ #include <ATen/ops/frexp_ops.h>
175
+ #include <ATen/ops/gather_ops.h>
176
+ #include <ATen/ops/gcd_ops.h>
177
+ #include <ATen/ops/ge_ops.h>
178
+ #include <ATen/ops/geometric_ops.h>
179
+ #include <ATen/ops/geqrf_ops.h>
180
+ #include <ATen/ops/ger_ops.h>
181
+ #include <ATen/ops/greater_equal_ops.h>
182
+ #include <ATen/ops/greater_ops.h>
183
+ #include <ATen/ops/gt_ops.h>
184
+ #include <ATen/ops/hardshrink_backward_ops.h>
185
+ #include <ATen/ops/hardshrink_ops.h>
186
+ #include <ATen/ops/heaviside_ops.h>
187
+ #include <ATen/ops/histc_ops.h>
188
+ #include <ATen/ops/histogram_ops.h>
189
+ #include <ATen/ops/hsplit_ops.h>
190
+ #include <ATen/ops/hypot_ops.h>
191
+ #include <ATen/ops/i0_ops.h>
192
+ #include <ATen/ops/igamma_ops.h>
193
+ #include <ATen/ops/igammac_ops.h>
194
+ #include <ATen/ops/index_add_ops.h>
195
+ #include <ATen/ops/index_copy_ops.h>
196
+ #include <ATen/ops/index_fill_ops.h>
197
+ #include <ATen/ops/index_ops.h>
198
+ #include <ATen/ops/index_put_ops.h>
199
+ #include <ATen/ops/index_reduce_ops.h>
200
+ #include <ATen/ops/index_select_ops.h>
201
+ #include <ATen/ops/indices_ops.h>
202
+ #include <ATen/ops/inner_ops.h>
203
+ #include <ATen/ops/int_repr_ops.h>
204
+ #include <ATen/ops/inverse_ops.h>
205
+ #include <ATen/ops/is_coalesced_ops.h>
206
+ #include <ATen/ops/is_complex_ops.h>
207
+ #include <ATen/ops/is_conj_ops.h>
208
+ #include <ATen/ops/is_distributed_ops.h>
209
+ #include <ATen/ops/is_floating_point_ops.h>
210
+ #include <ATen/ops/is_inference_ops.h>
211
+ #include <ATen/ops/is_leaf_ops.h>
212
+ #include <ATen/ops/is_neg_ops.h>
213
+ #include <ATen/ops/is_nonzero_ops.h>
214
+ #include <ATen/ops/is_pinned_ops.h>
215
+ #include <ATen/ops/is_same_size_ops.h>
216
+ #include <ATen/ops/is_set_to_ops.h>
217
+ #include <ATen/ops/is_signed_ops.h>
218
+ #include <ATen/ops/isclose_ops.h>
219
+ #include <ATen/ops/isfinite_ops.h>
220
+ #include <ATen/ops/isinf_ops.h>
221
+ #include <ATen/ops/isnan_ops.h>
222
+ #include <ATen/ops/isneginf_ops.h>
223
+ #include <ATen/ops/isposinf_ops.h>
224
+ #include <ATen/ops/isreal_ops.h>
225
+ #include <ATen/ops/istft_ops.h>
226
+ #include <ATen/ops/item_ops.h>
227
+ #include <ATen/ops/kron_ops.h>
228
+ #include <ATen/ops/kthvalue_ops.h>
229
+ #include <ATen/ops/lcm_ops.h>
230
+ #include <ATen/ops/ldexp_ops.h>
231
+ #include <ATen/ops/le_ops.h>
232
+ #include <ATen/ops/lerp_ops.h>
233
+ #include <ATen/ops/less_equal_ops.h>
234
+ #include <ATen/ops/less_ops.h>
235
+ #include <ATen/ops/lgamma_ops.h>
236
+ #include <ATen/ops/log10_ops.h>
237
+ #include <ATen/ops/log1p_ops.h>
238
+ #include <ATen/ops/log2_ops.h>
239
+ #include <ATen/ops/log_normal_ops.h>
240
+ #include <ATen/ops/log_ops.h>
241
+ #include <ATen/ops/log_softmax_ops.h>
242
+ #include <ATen/ops/logaddexp2_ops.h>
243
+ #include <ATen/ops/logaddexp_ops.h>
244
+ #include <ATen/ops/logcumsumexp_ops.h>
245
+ #include <ATen/ops/logdet_ops.h>
246
+ #include <ATen/ops/logical_and_ops.h>
247
+ #include <ATen/ops/logical_not_ops.h>
248
+ #include <ATen/ops/logical_or_ops.h>
249
+ #include <ATen/ops/logical_xor_ops.h>
250
+ #include <ATen/ops/logit_ops.h>
251
+ #include <ATen/ops/logsumexp_ops.h>
252
+ #include <ATen/ops/lshift_ops.h>
253
+ #include <ATen/ops/lt_ops.h>
254
+ #include <ATen/ops/lu_solve_ops.h>
255
+ #include <ATen/ops/mH_ops.h>
256
+ #include <ATen/ops/mT_ops.h>
257
+ #include <ATen/ops/masked_fill_ops.h>
258
+ #include <ATen/ops/masked_scatter_ops.h>
259
+ #include <ATen/ops/masked_select_ops.h>
260
+ #include <ATen/ops/matmul_ops.h>
261
+ #include <ATen/ops/matrix_H_ops.h>
262
+ #include <ATen/ops/matrix_exp_ops.h>
263
+ #include <ATen/ops/matrix_power_ops.h>
264
+ #include <ATen/ops/max_ops.h>
265
+ #include <ATen/ops/maximum_ops.h>
266
+ #include <ATen/ops/mean_ops.h>
267
+ #include <ATen/ops/median_ops.h>
268
+ #include <ATen/ops/min_ops.h>
269
+ #include <ATen/ops/minimum_ops.h>
270
+ #include <ATen/ops/mm_ops.h>
271
+ #include <ATen/ops/mode_ops.h>
272
+ #include <ATen/ops/moveaxis_ops.h>
273
+ #include <ATen/ops/movedim_ops.h>
274
+ #include <ATen/ops/msort_ops.h>
275
+ #include <ATen/ops/mul_ops.h>
276
+ #include <ATen/ops/multinomial_ops.h>
277
+ #include <ATen/ops/multiply_ops.h>
278
+ #include <ATen/ops/mv_ops.h>
279
+ #include <ATen/ops/mvlgamma_ops.h>
280
+ #include <ATen/ops/nan_to_num_ops.h>
281
+ #include <ATen/ops/nanmean_ops.h>
282
+ #include <ATen/ops/nanmedian_ops.h>
283
+ #include <ATen/ops/nanquantile_ops.h>
284
+ #include <ATen/ops/nansum_ops.h>
285
+ #include <ATen/ops/narrow_copy_ops.h>
286
+ #include <ATen/ops/narrow_ops.h>
287
+ #include <ATen/ops/ne_ops.h>
288
+ #include <ATen/ops/neg_ops.h>
289
+ #include <ATen/ops/negative_ops.h>
290
+ #include <ATen/ops/new_empty_ops.h>
291
+ #include <ATen/ops/new_empty_strided_ops.h>
292
+ #include <ATen/ops/new_full_ops.h>
293
+ #include <ATen/ops/new_ones_ops.h>
294
+ #include <ATen/ops/new_zeros_ops.h>
295
+ #include <ATen/ops/nextafter_ops.h>
296
+ #include <ATen/ops/nonzero_numpy_ops.h>
297
+ #include <ATen/ops/nonzero_ops.h>
298
+ #include <ATen/ops/nonzero_static_ops.h>
299
+ #include <ATen/ops/norm_ops.h>
300
+ #include <ATen/ops/normal_ops.h>
301
+ #include <ATen/ops/not_equal_ops.h>
302
+ #include <ATen/ops/numpy_T_ops.h>
303
+ #include <ATen/ops/or_ops.h>
304
+ #include <ATen/ops/orgqr_ops.h>
305
+ #include <ATen/ops/ormqr_ops.h>
306
+ #include <ATen/ops/outer_ops.h>
307
+ #include <ATen/ops/output_nr_ops.h>
308
+ #include <ATen/ops/permute_ops.h>
309
+ #include <ATen/ops/pin_memory_ops.h>
310
+ #include <ATen/ops/pinverse_ops.h>
311
+ #include <ATen/ops/polygamma_ops.h>
312
+ #include <ATen/ops/positive_ops.h>
313
+ #include <ATen/ops/pow_ops.h>
314
+ #include <ATen/ops/prelu_ops.h>
315
+ #include <ATen/ops/prod_ops.h>
316
+ #include <ATen/ops/put_ops.h>
317
+ #include <ATen/ops/q_per_channel_axis_ops.h>
318
+ #include <ATen/ops/q_per_channel_scales_ops.h>
319
+ #include <ATen/ops/q_per_channel_zero_points_ops.h>
320
+ #include <ATen/ops/q_scale_ops.h>
321
+ #include <ATen/ops/q_zero_point_ops.h>
322
+ #include <ATen/ops/qr_ops.h>
323
+ #include <ATen/ops/qscheme_ops.h>
324
+ #include <ATen/ops/quantile_ops.h>
325
+ #include <ATen/ops/rad2deg_ops.h>
326
+ #include <ATen/ops/random_ops.h>
327
+ #include <ATen/ops/ravel_ops.h>
328
+ #include <ATen/ops/reciprocal_ops.h>
329
+ #include <ATen/ops/record_stream_ops.h>
330
+ #include <ATen/ops/refine_names_ops.h>
331
+ #include <ATen/ops/relu_ops.h>
332
+ #include <ATen/ops/remainder_ops.h>
333
+ #include <ATen/ops/rename_ops.h>
334
+ #include <ATen/ops/renorm_ops.h>
335
+ #include <ATen/ops/repeat_interleave_ops.h>
336
+ #include <ATen/ops/repeat_ops.h>
337
+ #include <ATen/ops/requires_grad_ops.h>
338
+ #include <ATen/ops/reshape_as_ops.h>
339
+ #include <ATen/ops/reshape_ops.h>
340
+ #include <ATen/ops/resize_as_ops.h>
341
+ #include <ATen/ops/resize_as_sparse_ops.h>
342
+ #include <ATen/ops/resize_ops.h>
343
+ #include <ATen/ops/resolve_conj_ops.h>
344
+ #include <ATen/ops/resolve_neg_ops.h>
345
+ #include <ATen/ops/retain_grad_ops.h>
346
+ #include <ATen/ops/retains_grad_ops.h>
347
+ #include <ATen/ops/roll_ops.h>
348
+ #include <ATen/ops/rot90_ops.h>
349
+ #include <ATen/ops/round_ops.h>
350
+ #include <ATen/ops/row_indices_ops.h>
351
+ #include <ATen/ops/rshift_ops.h>
352
+ #include <ATen/ops/rsqrt_ops.h>
353
+ #include <ATen/ops/scatter_add_ops.h>
354
+ #include <ATen/ops/scatter_ops.h>
355
+ #include <ATen/ops/scatter_reduce_ops.h>
356
+ #include <ATen/ops/select_ops.h>
357
+ #include <ATen/ops/select_scatter_ops.h>
358
+ #include <ATen/ops/set_data_ops.h>
359
+ #include <ATen/ops/set_ops.h>
360
+ #include <ATen/ops/sgn_ops.h>
361
+ #include <ATen/ops/sigmoid_ops.h>
362
+ #include <ATen/ops/sign_ops.h>
363
+ #include <ATen/ops/signbit_ops.h>
364
+ #include <ATen/ops/sin_ops.h>
365
+ #include <ATen/ops/sinc_ops.h>
366
+ #include <ATen/ops/sinh_ops.h>
367
+ #include <ATen/ops/size_ops.h>
368
+ #include <ATen/ops/slice_inverse_ops.h>
369
+ #include <ATen/ops/slice_ops.h>
370
+ #include <ATen/ops/slice_scatter_ops.h>
371
+ #include <ATen/ops/slogdet_ops.h>
372
+ #include <ATen/ops/smm_ops.h>
373
+ #include <ATen/ops/softmax_ops.h>
374
+ #include <ATen/ops/sort_ops.h>
375
+ #include <ATen/ops/sparse_dim_ops.h>
376
+ #include <ATen/ops/sparse_mask_ops.h>
377
+ #include <ATen/ops/sparse_resize_and_clear_ops.h>
378
+ #include <ATen/ops/sparse_resize_ops.h>
379
+ #include <ATen/ops/split_ops.h>
380
+ #include <ATen/ops/split_with_sizes_ops.h>
381
+ #include <ATen/ops/sqrt_ops.h>
382
+ #include <ATen/ops/square_ops.h>
383
+ #include <ATen/ops/squeeze_ops.h>
384
+ #include <ATen/ops/sspaddmm_ops.h>
385
+ #include <ATen/ops/std_ops.h>
386
+ #include <ATen/ops/stft_ops.h>
387
+ #include <ATen/ops/stride_ops.h>
388
+ #include <ATen/ops/sub_ops.h>
389
+ #include <ATen/ops/subtract_ops.h>
390
+ #include <ATen/ops/sum_ops.h>
391
+ #include <ATen/ops/sum_to_size_ops.h>
392
+ #include <ATen/ops/svd_ops.h>
393
+ #include <ATen/ops/swapaxes_ops.h>
394
+ #include <ATen/ops/swapdims_ops.h>
395
+ #include <ATen/ops/t_ops.h>
396
+ #include <ATen/ops/take_along_dim_ops.h>
397
+ #include <ATen/ops/take_ops.h>
398
+ #include <ATen/ops/tan_ops.h>
399
+ #include <ATen/ops/tanh_ops.h>
400
+ #include <ATen/ops/tensor_split_ops.h>
401
+ #include <ATen/ops/tile_ops.h>
402
+ #include <ATen/ops/to_dense_ops.h>
403
+ #include <ATen/ops/to_mkldnn_ops.h>
404
+ #include <ATen/ops/to_ops.h>
405
+ #include <ATen/ops/to_padded_tensor_ops.h>
406
+ #include <ATen/ops/to_sparse_bsc_ops.h>
407
+ #include <ATen/ops/to_sparse_bsr_ops.h>
408
+ #include <ATen/ops/to_sparse_csc_ops.h>
409
+ #include <ATen/ops/to_sparse_csr_ops.h>
410
+ #include <ATen/ops/to_sparse_ops.h>
411
+ #include <ATen/ops/topk_ops.h>
412
+ #include <ATen/ops/trace_ops.h>
413
+ #include <ATen/ops/transpose_ops.h>
414
+ #include <ATen/ops/triangular_solve_ops.h>
415
+ #include <ATen/ops/tril_ops.h>
416
+ #include <ATen/ops/triu_ops.h>
417
+ #include <ATen/ops/true_divide_ops.h>
418
+ #include <ATen/ops/trunc_ops.h>
419
+ #include <ATen/ops/type_as_ops.h>
420
+ #include <ATen/ops/unbind_ops.h>
421
+ #include <ATen/ops/unflatten_ops.h>
422
+ #include <ATen/ops/unfold_ops.h>
423
+ #include <ATen/ops/uniform_ops.h>
424
+ #include <ATen/ops/unsafe_chunk_ops.h>
425
+ #include <ATen/ops/unsafe_split_ops.h>
426
+ #include <ATen/ops/unsafe_split_with_sizes_ops.h>
427
+ #include <ATen/ops/unsqueeze_ops.h>
428
+ #include <ATen/ops/values_ops.h>
429
+ #include <ATen/ops/var_ops.h>
430
+ #include <ATen/ops/vdot_ops.h>
431
+ #include <ATen/ops/view_as_ops.h>
432
+ #include <ATen/ops/view_ops.h>
433
+ #include <ATen/ops/vsplit_ops.h>
434
+ #include <ATen/ops/where_ops.h>
435
+ #include <ATen/ops/xlogy_ops.h>
436
+ #include <ATen/ops/xor_ops.h>
437
+ #include <ATen/ops/zero_ops.h>
438
+
439
+ namespace at {
440
+ namespace _ops {
441
+
442
+ } // namespace _ops
443
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensor.h ADDED
@@ -0,0 +1 @@
 
 
1
+ #include <ATen/core/NamedTensor.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensorUtils.h ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/NamedTensor.h>
3
+ #include <ATen/TensorNames.h>
4
+ #include <ATen/WrapDimUtilsMulti.h>
5
+
6
+ #include <ATen/core/DimVector.h>
7
+ #include <ATen/core/Tensor.h>
8
+
9
+ namespace at {
10
+
11
+ using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
12
+
13
+ inline bool has_names(const ITensorListRef& tensors) {
14
+ return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& t) {
15
+ return t.has_names();
16
+ });
17
+ }
18
+
19
+ // Converts dim to an positional index. Errors if `dim` cannot be used to
20
+ // refer to any dimension of tensor.
21
+ TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim);
22
+ TORCH_API std::vector<int64_t> dimnames_to_positions(
23
+ const Tensor& tensor,
24
+ DimnameList dims);
25
+
26
+ // Unifies two DimnameList to produce a third. This is useful for implementing
27
+ // the named inference rule for binary broadcasting operations like add.
28
+ //
29
+ // There are three main constraints:
30
+ // 1) Check matching: Names must match positionally from the right.
31
+ // 2) Check misaligned: If a name `n` is in `names`, then it must appear at
32
+ // the same index from the right in other.
33
+ // 3) The output names are obtained by unifying the names individually from the
34
+ // right.
35
+ TORCH_API std::vector<Dimname> unify_from_right(
36
+ DimnameList names,
37
+ DimnameList other,
38
+ const char* action = "broadcast");
39
+
40
+ [[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) {
41
+ TORCH_CHECK(
42
+ false,
43
+ op_name,
44
+ ": You passed a dimname (string) to this op in place of a dimension "
45
+ "index but it does not yet support this behavior. Please pass a dimension "
46
+ "index to work around this.");
47
+ }
48
+
49
+ // [NOTE] Writing name inference rules
50
+ //
51
+ // Operators that support named tensors are either composed of operations that
52
+ // support named tensors or implement some name inference rule. An op that
53
+ // implements its own name inference rule generally looks like the following:
54
+ //
55
+ // Tensor op(...) {
56
+ // perform_shape_checks(...);
57
+ // # (1)
58
+ // auto maybe_outnames = compute_outnames(...);
59
+ // auto result = [&]() {
60
+ // NoNamesGuard guard;
61
+ // return op_impl(...);
62
+ // }();
63
+ // # (2)
64
+ // propagate_names_if_nonempty(result, maybe_outnames);
65
+ //
66
+ // Each op has (1) a compute outnames step and (2) a propagate names step.
67
+ //
68
+ // compute_outnames is responsible for checking that input names match and
69
+ // determining what the output names should be. It returns either:
70
+ // - {} (if the inputs tensors are all unnamed)
71
+ // - non-empty outnames.
72
+ //
73
+ // propagate_names_if_nonempty propagates the outnames if they exist to the
74
+ // result tensors.
75
+ //
76
+ // The {} case is an optimization; if the user does not use named tensors they
77
+ // pay no perf cost for it.
78
+
79
+ namespace namedinference {
80
+
81
+ const Tensor& propagate_names_if_present_and_nonempty(
82
+ const Tensor& result,
83
+ std::optional<DimnameList> maybe_names,
84
+ bool validate_names = false);
85
+ // Propagates `names` to `result` if `names` is not empty.
86
+ // `names` can be empty; see [NOTE] Writing name inference rules
87
+ // If `names` is not empty, `names.size()` should equal `result.dim()`.
88
+ // When in doubt, use this overload instead of the others.
89
+ TORCH_API const Tensor& propagate_names_if_nonempty(
90
+ const Tensor& result,
91
+ DimnameList maybe_names,
92
+ bool validate_names = false);
93
+
94
+ // Propagates `names` to `result`. Only use this if we are certain that there
95
+ // are names to propagate (that names is not empty).
96
+ TORCH_API const Tensor& propagate_names(
97
+ const Tensor& result,
98
+ DimnameList names,
99
+ bool validate_names = false);
100
+
101
+ // Propagates all names from src to result.
102
+ TORCH_API void propagate_names(const Tensor& result, const Tensor& src);
103
+
104
+ // Propagates all names except for those at the excluded_idxs.
105
+ TORCH_API void propagate_names_except(
106
+ const Tensor& result,
107
+ const Tensor& src,
108
+ IntArrayRef excluded_idxs);
109
+
110
+ // Used for reduction ops that have a `keepdim` arg.
111
+ TORCH_API void propagate_names_for_reduction(
112
+ const Tensor& result,
113
+ const Tensor& src,
114
+ IntArrayRef excluded_idxs,
115
+ bool keepdim);
116
+
117
+ TORCH_API void propagate_names_for_expand(
118
+ const Tensor& result,
119
+ const Tensor& self);
120
+
121
+ TORCH_API std::vector<Dimname> compute_cat_outnames(
122
+ const MaterializedITensorListRef& tensors);
123
+
124
+ TORCH_API std::vector<Dimname> compute_broadcast_outnames(
125
+ const Tensor& self,
126
+ const Tensor& other);
127
+
128
+ TORCH_API std::vector<Dimname> broadcast_to_outnames(
129
+ const Tensor& tensor,
130
+ const Tensor& reference_tensor,
131
+ const char* op_name);
132
+
133
+ TORCH_API std::vector<Dimname> compute_matmul_outnames(
134
+ const Tensor& self,
135
+ const Tensor& other);
136
+
137
+ TORCH_API std::vector<Dimname> compute_cdist_outnames(
138
+ const Tensor& self,
139
+ const Tensor& other);
140
+
141
+ TORCH_API std::vector<Dimname> compute_bmm_outnames(
142
+ const Tensor& result,
143
+ const Tensor& self,
144
+ const Tensor& other);
145
+
146
+ TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
147
+ TORCH_API std::vector<Dimname> compute_squeeze_outnames(
148
+ const Tensor& tensor,
149
+ std::bitset<dim_bitset_size> dims);
150
+
151
+ std::vector<Dimname> compute_diagonal_outnames(
152
+ const Tensor& tensor,
153
+ int64_t dim1,
154
+ int64_t dim2);
155
+
156
+ // TensorImpl* overloads for Legacy TH/THC code. Use these sparingly.
157
+
158
+ TORCH_API TensorImpl* propagate_names_if_nonempty(
159
+ TensorImpl* result,
160
+ DimnameList maybe_names,
161
+ bool validate_names = false);
162
+
163
+ TORCH_API TensorImpl* propagate_names(
164
+ TensorImpl* result,
165
+ DimnameList names,
166
+ bool validate_names = false);
167
+
168
+ TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src);
169
+
170
+ TORCH_API inline void propagate_names(
171
+ const TensorBase& result,
172
+ DimnameList names,
173
+ bool validate_names = false) {
174
+ propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
175
+ }
176
+
177
+ TORCH_API inline void propagate_names_if_nonempty(
178
+ const TensorBase& result,
179
+ DimnameList names,
180
+ bool validate_names = false) {
181
+ propagate_names_if_nonempty(
182
+ result.unsafeGetTensorImpl(), names, validate_names);
183
+ }
184
+
185
+ TORCH_API inline void propagate_names(
186
+ const TensorBase& result,
187
+ const TensorBase& src) {
188
+ propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
189
+ }
190
+
191
+ // result = m1 @ m2 + bias
192
+ TORCH_API std::vector<Dimname> propagate_names_for_addmm(
193
+ const Tensor& m1,
194
+ const Tensor& m2,
195
+ const Tensor& bias);
196
+
197
+ TORCH_API std::vector<Dimname> propagate_names_for_addmv(
198
+ const Tensor& mat,
199
+ const Tensor& vec,
200
+ const Tensor& bias);
201
+
202
+ TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);
203
+
204
+ TORCH_API std::vector<Dimname> compute_baddbmm_outnames(
205
+ const Tensor& result,
206
+ const Tensor& self,
207
+ const Tensor& other,
208
+ const Tensor& bias);
209
+
210
+ TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other);
211
+
212
+ } // namespace namedinference
213
+
214
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/NativeFunctions.h ADDED
@@ -0,0 +1,1344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunctions.h
4
+
5
+ #ifdef TORCH_ASSERT_NO_OPERATORS
6
+ #error This change adds a dependency on native_functions.yaml, \
7
+ meaning the file will need to be re-compiled every time an operator \
8
+ is changed or added. Consider if your change would be better placed in \
9
+ another file, or if a more specific header might achieve the same goal. \
10
+ See NOTE: [Tensor vs. TensorBase]
11
+ #endif
12
+
13
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
14
+ #error This change adds a dependency on all pytorch operators, meaning the \
15
+ file will need to be re-compiled every time an operator is changed or added. \
16
+ Consider including a specific operator from <ATen/ops/{my_operator}_native.h> \
17
+ and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
18
+ #endif
19
+
20
+ #include <c10/core/Scalar.h>
21
+ #include <c10/core/Storage.h>
22
+ #include <c10/core/TensorOptions.h>
23
+ #include <c10/util/Deprecated.h>
24
+ #include <optional>
25
+ #include <c10/core/QScheme.h>
26
+ #include <ATen/core/Reduction.h>
27
+ #include <ATen/core/Tensor.h>
28
+ #include <tuple>
29
+ #include <vector>
30
+
31
+ #include <ATen/ops/_adaptive_avg_pool2d_native.h>
32
+ #include <ATen/ops/_adaptive_avg_pool2d_backward_native.h>
33
+ #include <ATen/ops/_adaptive_avg_pool3d_native.h>
34
+ #include <ATen/ops/_adaptive_avg_pool3d_backward_native.h>
35
+ #include <ATen/ops/_add_batch_dim_native.h>
36
+ #include <ATen/ops/_add_relu_native.h>
37
+ #include <ATen/ops/_addmm_activation_native.h>
38
+ #include <ATen/ops/_aminmax_native.h>
39
+ #include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_native.h>
40
+ #include <ATen/ops/_amp_update_scale_native.h>
41
+ #include <ATen/ops/_assert_async_native.h>
42
+ #include <ATen/ops/_assert_scalar_native.h>
43
+ #include <ATen/ops/_assert_tensor_metadata_native.h>
44
+ #include <ATen/ops/_autocast_to_full_precision_native.h>
45
+ #include <ATen/ops/_autocast_to_reduced_precision_native.h>
46
+ #include <ATen/ops/_backward_native.h>
47
+ #include <ATen/ops/_batch_norm_impl_index_native.h>
48
+ #include <ATen/ops/_batch_norm_impl_index_backward_native.h>
49
+ #include <ATen/ops/_batch_norm_no_update_native.h>
50
+ #include <ATen/ops/_batch_norm_with_update_native.h>
51
+ #include <ATen/ops/_cast_Byte_native.h>
52
+ #include <ATen/ops/_cast_Char_native.h>
53
+ #include <ATen/ops/_cast_Double_native.h>
54
+ #include <ATen/ops/_cast_Float_native.h>
55
+ #include <ATen/ops/_cast_Half_native.h>
56
+ #include <ATen/ops/_cast_Int_native.h>
57
+ #include <ATen/ops/_cast_Long_native.h>
58
+ #include <ATen/ops/_cast_Short_native.h>
59
+ #include <ATen/ops/_cdist_backward_native.h>
60
+ #include <ATen/ops/_cdist_forward_native.h>
61
+ #include <ATen/ops/_cholesky_solve_helper_native.h>
62
+ #include <ATen/ops/_choose_qparams_per_tensor_native.h>
63
+ #include <ATen/ops/_chunk_cat_native.h>
64
+ #include <ATen/ops/_coalesce_native.h>
65
+ #include <ATen/ops/_coalesced_native.h>
66
+ #include <ATen/ops/_compute_linear_combination_native.h>
67
+ #include <ATen/ops/_conj_native.h>
68
+ #include <ATen/ops/_conj_copy_native.h>
69
+ #include <ATen/ops/_conj_physical_native.h>
70
+ #include <ATen/ops/_conv_depthwise2d_native.h>
71
+ #include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
72
+ #include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
73
+ #include <ATen/ops/_convert_weight_to_int4pack_native.h>
74
+ #include <ATen/ops/_convolution_native.h>
75
+ #include <ATen/ops/_convolution_double_backward_native.h>
76
+ #include <ATen/ops/_convolution_mode_native.h>
77
+ #include <ATen/ops/_copy_from_native.h>
78
+ #include <ATen/ops/_copy_from_and_resize_native.h>
79
+ #include <ATen/ops/_cslt_compress_native.h>
80
+ #include <ATen/ops/_cslt_sparse_mm_native.h>
81
+ #include <ATen/ops/_cslt_sparse_mm_search_native.h>
82
+ #include <ATen/ops/_ctc_loss_native.h>
83
+ #include <ATen/ops/_ctc_loss_backward_native.h>
84
+ #include <ATen/ops/_cudnn_ctc_loss_native.h>
85
+ #include <ATen/ops/_cudnn_init_dropout_state_native.h>
86
+ #include <ATen/ops/_cudnn_rnn_native.h>
87
+ #include <ATen/ops/_cudnn_rnn_backward_native.h>
88
+ #include <ATen/ops/_cudnn_rnn_flatten_weight_native.h>
89
+ #include <ATen/ops/_cufft_clear_plan_cache_native.h>
90
+ #include <ATen/ops/_cufft_get_plan_cache_max_size_native.h>
91
+ #include <ATen/ops/_cufft_get_plan_cache_size_native.h>
92
+ #include <ATen/ops/_cufft_set_plan_cache_max_size_native.h>
93
+ #include <ATen/ops/_cummax_helper_native.h>
94
+ #include <ATen/ops/_cummin_helper_native.h>
95
+ #include <ATen/ops/_debug_has_internal_overlap_native.h>
96
+ #include <ATen/ops/_dimI_native.h>
97
+ #include <ATen/ops/_dimV_native.h>
98
+ #include <ATen/ops/_dim_arange_native.h>
99
+ #include <ATen/ops/_dirichlet_grad_native.h>
100
+ #include <ATen/ops/_efficient_attention_backward_native.h>
101
+ #include <ATen/ops/_efficient_attention_forward_native.h>
102
+ #include <ATen/ops/_efficientzerotensor_native.h>
103
+ #include <ATen/ops/_embedding_bag_native.h>
104
+ #include <ATen/ops/_embedding_bag_backward_native.h>
105
+ #include <ATen/ops/_embedding_bag_dense_backward_native.h>
106
+ #include <ATen/ops/_embedding_bag_forward_only_native.h>
107
+ #include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h>
108
+ #include <ATen/ops/_embedding_bag_sparse_backward_native.h>
109
+ #include <ATen/ops/_empty_affine_quantized_native.h>
110
+ #include <ATen/ops/_empty_per_channel_affine_quantized_native.h>
111
+ #include <ATen/ops/_euclidean_dist_native.h>
112
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine_native.h>
113
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_native.h>
114
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_native.h>
115
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_native.h>
116
+ #include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_native.h>
117
+ #include <ATen/ops/_fft_c2c_native.h>
118
+ #include <ATen/ops/_fft_c2r_native.h>
119
+ #include <ATen/ops/_fft_r2c_native.h>
120
+ #include <ATen/ops/_fill_mem_eff_dropout_mask_native.h>
121
+ #include <ATen/ops/_flash_attention_backward_native.h>
122
+ #include <ATen/ops/_flash_attention_forward_native.h>
123
+ #include <ATen/ops/_foobar_native.h>
124
+ #include <ATen/ops/_foreach_abs_native.h>
125
+ #include <ATen/ops/_foreach_acos_native.h>
126
+ #include <ATen/ops/_foreach_add_native.h>
127
+ #include <ATen/ops/_foreach_addcdiv_native.h>
128
+ #include <ATen/ops/_foreach_addcmul_native.h>
129
+ #include <ATen/ops/_foreach_asin_native.h>
130
+ #include <ATen/ops/_foreach_atan_native.h>
131
+ #include <ATen/ops/_foreach_ceil_native.h>
132
+ #include <ATen/ops/_foreach_clamp_max_native.h>
133
+ #include <ATen/ops/_foreach_clamp_min_native.h>
134
+ #include <ATen/ops/_foreach_copy_native.h>
135
+ #include <ATen/ops/_foreach_cos_native.h>
136
+ #include <ATen/ops/_foreach_cosh_native.h>
137
+ #include <ATen/ops/_foreach_div_native.h>
138
+ #include <ATen/ops/_foreach_erf_native.h>
139
+ #include <ATen/ops/_foreach_erfc_native.h>
140
+ #include <ATen/ops/_foreach_exp_native.h>
141
+ #include <ATen/ops/_foreach_expm1_native.h>
142
+ #include <ATen/ops/_foreach_floor_native.h>
143
+ #include <ATen/ops/_foreach_frac_native.h>
144
+ #include <ATen/ops/_foreach_lerp_native.h>
145
+ #include <ATen/ops/_foreach_lgamma_native.h>
146
+ #include <ATen/ops/_foreach_log_native.h>
147
+ #include <ATen/ops/_foreach_log10_native.h>
148
+ #include <ATen/ops/_foreach_log1p_native.h>
149
+ #include <ATen/ops/_foreach_log2_native.h>
150
+ #include <ATen/ops/_foreach_max_native.h>
151
+ #include <ATen/ops/_foreach_maximum_native.h>
152
+ #include <ATen/ops/_foreach_minimum_native.h>
153
+ #include <ATen/ops/_foreach_mul_native.h>
154
+ #include <ATen/ops/_foreach_neg_native.h>
155
+ #include <ATen/ops/_foreach_norm_native.h>
156
+ #include <ATen/ops/_foreach_pow_native.h>
157
+ #include <ATen/ops/_foreach_reciprocal_native.h>
158
+ #include <ATen/ops/_foreach_round_native.h>
159
+ #include <ATen/ops/_foreach_sigmoid_native.h>
160
+ #include <ATen/ops/_foreach_sign_native.h>
161
+ #include <ATen/ops/_foreach_sin_native.h>
162
+ #include <ATen/ops/_foreach_sinh_native.h>
163
+ #include <ATen/ops/_foreach_sqrt_native.h>
164
+ #include <ATen/ops/_foreach_sub_native.h>
165
+ #include <ATen/ops/_foreach_tan_native.h>
166
+ #include <ATen/ops/_foreach_tanh_native.h>
167
+ #include <ATen/ops/_foreach_trunc_native.h>
168
+ #include <ATen/ops/_foreach_zero_native.h>
169
+ #include <ATen/ops/_functional_assert_async_native.h>
170
+ #include <ATen/ops/_functional_assert_scalar_native.h>
171
+ #include <ATen/ops/_functional_sym_constrain_range_native.h>
172
+ #include <ATen/ops/_functional_sym_constrain_range_for_size_native.h>
173
+ #include <ATen/ops/_fused_adagrad_native.h>
174
+ #include <ATen/ops/_fused_adam_native.h>
175
+ #include <ATen/ops/_fused_adamw_native.h>
176
+ #include <ATen/ops/_fused_dropout_native.h>
177
+ #include <ATen/ops/_fused_moving_avg_obs_fq_helper_native.h>
178
+ #include <ATen/ops/_fused_sdp_choice_native.h>
179
+ #include <ATen/ops/_fused_sgd_native.h>
180
+ #include <ATen/ops/_fw_primal_native.h>
181
+ #include <ATen/ops/_fw_primal_copy_native.h>
182
+ #include <ATen/ops/_gather_sparse_backward_native.h>
183
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback_native.h>
184
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_native.h>
185
+ #include <ATen/ops/_has_compatible_shallow_copy_type_native.h>
186
+ #include <ATen/ops/_has_same_storage_numel_native.h>
187
+ #include <ATen/ops/_histogramdd_bin_edges_native.h>
188
+ #include <ATen/ops/_histogramdd_from_bin_cts_native.h>
189
+ #include <ATen/ops/_histogramdd_from_bin_tensors_native.h>
190
+ #include <ATen/ops/_index_put_impl_native.h>
191
+ #include <ATen/ops/_indices_native.h>
192
+ #include <ATen/ops/_indices_copy_native.h>
193
+ #include <ATen/ops/_int_mm_native.h>
194
+ #include <ATen/ops/_is_all_true_native.h>
195
+ #include <ATen/ops/_is_any_true_native.h>
196
+ #include <ATen/ops/_is_zerotensor_native.h>
197
+ #include <ATen/ops/_jagged_to_padded_dense_forward_native.h>
198
+ #include <ATen/ops/_lazy_clone_native.h>
199
+ #include <ATen/ops/_linalg_check_errors_native.h>
200
+ #include <ATen/ops/_linalg_det_native.h>
201
+ #include <ATen/ops/_linalg_eigh_native.h>
202
+ #include <ATen/ops/_linalg_eigvals_native.h>
203
+ #include <ATen/ops/_linalg_slogdet_native.h>
204
+ #include <ATen/ops/_linalg_solve_ex_native.h>
205
+ #include <ATen/ops/_linalg_svd_native.h>
206
+ #include <ATen/ops/_local_scalar_dense_native.h>
207
+ #include <ATen/ops/_log_softmax_native.h>
208
+ #include <ATen/ops/_log_softmax_backward_data_native.h>
209
+ #include <ATen/ops/_logcumsumexp_native.h>
210
+ #include <ATen/ops/_lstm_mps_native.h>
211
+ #include <ATen/ops/_lu_with_info_native.h>
212
+ #include <ATen/ops/_make_dep_token_native.h>
213
+ #include <ATen/ops/_make_dual_native.h>
214
+ #include <ATen/ops/_make_dual_copy_native.h>
215
+ #include <ATen/ops/_make_per_channel_quantized_tensor_native.h>
216
+ #include <ATen/ops/_make_per_tensor_quantized_tensor_native.h>
217
+ #include <ATen/ops/_masked_scale_native.h>
218
+ #include <ATen/ops/_masked_softmax_native.h>
219
+ #include <ATen/ops/_masked_softmax_backward_native.h>
220
+ #include <ATen/ops/_mixed_dtypes_linear_native.h>
221
+ #include <ATen/ops/_mkldnn_reshape_native.h>
222
+ #include <ATen/ops/_mkldnn_transpose_native.h>
223
+ #include <ATen/ops/_mps_convolution_native.h>
224
+ #include <ATen/ops/_mps_convolution_transpose_native.h>
225
+ #include <ATen/ops/_native_batch_norm_legit_native.h>
226
+ #include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
227
+ #include <ATen/ops/_native_multi_head_attention_native.h>
228
+ #include <ATen/ops/_neg_view_native.h>
229
+ #include <ATen/ops/_neg_view_copy_native.h>
230
+ #include <ATen/ops/_nested_compute_contiguous_strides_offsets_native.h>
231
+ #include <ATen/ops/_nested_from_padded_native.h>
232
+ #include <ATen/ops/_nested_from_padded_and_nested_example_native.h>
233
+ #include <ATen/ops/_nested_get_jagged_dummy_native.h>
234
+ #include <ATen/ops/_nested_get_lengths_native.h>
235
+ #include <ATen/ops/_nested_get_max_seqlen_native.h>
236
+ #include <ATen/ops/_nested_get_min_seqlen_native.h>
237
+ #include <ATen/ops/_nested_get_offsets_native.h>
238
+ #include <ATen/ops/_nested_get_ragged_idx_native.h>
239
+ #include <ATen/ops/_nested_get_values_native.h>
240
+ #include <ATen/ops/_nested_get_values_copy_native.h>
241
+ #include <ATen/ops/_nested_select_backward_native.h>
242
+ #include <ATen/ops/_nested_sum_backward_native.h>
243
+ #include <ATen/ops/_nested_tensor_from_mask_native.h>
244
+ #include <ATen/ops/_nested_tensor_from_mask_left_aligned_native.h>
245
+ #include <ATen/ops/_nested_tensor_from_tensor_list_native.h>
246
+ #include <ATen/ops/_nested_tensor_size_native.h>
247
+ #include <ATen/ops/_nested_tensor_softmax_with_shape_native.h>
248
+ #include <ATen/ops/_nested_tensor_storage_offsets_native.h>
249
+ #include <ATen/ops/_nested_tensor_strides_native.h>
250
+ #include <ATen/ops/_nested_view_from_buffer_native.h>
251
+ #include <ATen/ops/_nested_view_from_buffer_copy_native.h>
252
+ #include <ATen/ops/_nested_view_from_jagged_native.h>
253
+ #include <ATen/ops/_nested_view_from_jagged_copy_native.h>
254
+ #include <ATen/ops/_new_zeros_with_same_feature_meta_native.h>
255
+ #include <ATen/ops/_nnpack_available_native.h>
256
+ #include <ATen/ops/_nnpack_spatial_convolution_native.h>
257
+ #include <ATen/ops/_nnz_native.h>
258
+ #include <ATen/ops/_pack_padded_sequence_native.h>
259
+ #include <ATen/ops/_pack_padded_sequence_backward_native.h>
260
+ #include <ATen/ops/_pad_circular_native.h>
261
+ #include <ATen/ops/_pad_enum_native.h>
262
+ #include <ATen/ops/_pad_packed_sequence_native.h>
263
+ #include <ATen/ops/_padded_dense_to_jagged_forward_native.h>
264
+ #include <ATen/ops/_pdist_backward_native.h>
265
+ #include <ATen/ops/_pdist_forward_native.h>
266
+ #include <ATen/ops/_pin_memory_native.h>
267
+ #include <ATen/ops/_prelu_kernel_native.h>
268
+ #include <ATen/ops/_prelu_kernel_backward_native.h>
269
+ #include <ATen/ops/_print_native.h>
270
+ #include <ATen/ops/_propagate_xla_data_native.h>
271
+ #include <ATen/ops/_remove_batch_dim_native.h>
272
+ #include <ATen/ops/_reshape_alias_native.h>
273
+ #include <ATen/ops/_reshape_alias_copy_native.h>
274
+ #include <ATen/ops/_reshape_copy_native.h>
275
+ #include <ATen/ops/_reshape_from_tensor_native.h>
276
+ #include <ATen/ops/_resize_output_native.h>
277
+ #include <ATen/ops/_rowwise_prune_native.h>
278
+ #include <ATen/ops/_safe_softmax_native.h>
279
+ #include <ATen/ops/_sample_dirichlet_native.h>
280
+ #include <ATen/ops/_saturate_weight_to_fp16_native.h>
281
+ #include <ATen/ops/_scaled_dot_product_attention_math_native.h>
282
+ #include <ATen/ops/_scaled_dot_product_attention_math_for_mps_native.h>
283
+ #include <ATen/ops/_scaled_dot_product_cudnn_attention_native.h>
284
+ #include <ATen/ops/_scaled_dot_product_cudnn_attention_backward_native.h>
285
+ #include <ATen/ops/_scaled_dot_product_efficient_attention_native.h>
286
+ #include <ATen/ops/_scaled_dot_product_efficient_attention_backward_native.h>
287
+ #include <ATen/ops/_scaled_dot_product_flash_attention_native.h>
288
+ #include <ATen/ops/_scaled_dot_product_flash_attention_backward_native.h>
289
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_native.h>
290
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_native.h>
291
+ #include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_native.h>
292
+ #include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_native.h>
293
+ #include <ATen/ops/_scaled_mm_native.h>
294
+ #include <ATen/ops/_segment_reduce_backward_native.h>
295
+ #include <ATen/ops/_shape_as_tensor_native.h>
296
+ #include <ATen/ops/_slow_conv2d_backward_native.h>
297
+ #include <ATen/ops/_slow_conv2d_forward_native.h>
298
+ #include <ATen/ops/_sobol_engine_draw_native.h>
299
+ #include <ATen/ops/_sobol_engine_ff_native.h>
300
+ #include <ATen/ops/_sobol_engine_initialize_state_native.h>
301
+ #include <ATen/ops/_sobol_engine_scramble_native.h>
302
+ #include <ATen/ops/_softmax_native.h>
303
+ #include <ATen/ops/_softmax_backward_data_native.h>
304
+ #include <ATen/ops/_sparse_addmm_native.h>
305
+ #include <ATen/ops/_sparse_broadcast_to_native.h>
306
+ #include <ATen/ops/_sparse_broadcast_to_copy_native.h>
307
+ #include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
308
+ #include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
309
+ #include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
310
+ #include <ATen/ops/_sparse_compressed_tensor_with_dims_native.h>
311
+ #include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
312
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_native.h>
313
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_native.h>
314
+ #include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
315
+ #include <ATen/ops/_sparse_csr_prod_native.h>
316
+ #include <ATen/ops/_sparse_csr_sum_native.h>
317
+ #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
318
+ #include <ATen/ops/_sparse_log_softmax_native.h>
319
+ #include <ATen/ops/_sparse_log_softmax_backward_data_native.h>
320
+ #include <ATen/ops/_sparse_mask_projection_native.h>
321
+ #include <ATen/ops/_sparse_mm_native.h>
322
+ #include <ATen/ops/_sparse_mm_reduce_impl_native.h>
323
+ #include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
324
+ #include <ATen/ops/_sparse_semi_structured_addmm_native.h>
325
+ #include <ATen/ops/_sparse_semi_structured_apply_native.h>
326
+ #include <ATen/ops/_sparse_semi_structured_apply_dense_native.h>
327
+ #include <ATen/ops/_sparse_semi_structured_linear_native.h>
328
+ #include <ATen/ops/_sparse_semi_structured_mm_native.h>
329
+ #include <ATen/ops/_sparse_semi_structured_tile_native.h>
330
+ #include <ATen/ops/_sparse_softmax_native.h>
331
+ #include <ATen/ops/_sparse_softmax_backward_data_native.h>
332
+ #include <ATen/ops/_sparse_sparse_matmul_native.h>
333
+ #include <ATen/ops/_sparse_sum_native.h>
334
+ #include <ATen/ops/_sparse_sum_backward_native.h>
335
+ #include <ATen/ops/_spdiags_native.h>
336
+ #include <ATen/ops/_spsolve_native.h>
337
+ #include <ATen/ops/_stack_native.h>
338
+ #include <ATen/ops/_standard_gamma_native.h>
339
+ #include <ATen/ops/_standard_gamma_grad_native.h>
340
+ #include <ATen/ops/_test_ambiguous_defaults_native.h>
341
+ #include <ATen/ops/_test_autograd_multiple_dispatch_native.h>
342
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_native.h>
343
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_native.h>
344
+ #include <ATen/ops/_test_check_tensor_native.h>
345
+ #include <ATen/ops/_test_functorch_fallback_native.h>
346
+ #include <ATen/ops/_test_optional_filled_intlist_native.h>
347
+ #include <ATen/ops/_test_optional_floatlist_native.h>
348
+ #include <ATen/ops/_test_optional_intlist_native.h>
349
+ #include <ATen/ops/_test_parallel_materialize_native.h>
350
+ #include <ATen/ops/_test_serialization_subcmul_native.h>
351
+ #include <ATen/ops/_test_string_default_native.h>
352
+ #include <ATen/ops/_test_warn_in_autograd_native.h>
353
+ #include <ATen/ops/_thnn_differentiable_gru_cell_backward_native.h>
354
+ #include <ATen/ops/_thnn_differentiable_lstm_cell_backward_native.h>
355
+ #include <ATen/ops/_thnn_fused_gru_cell_native.h>
356
+ #include <ATen/ops/_thnn_fused_gru_cell_backward_native.h>
357
+ #include <ATen/ops/_thnn_fused_lstm_cell_native.h>
358
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward_native.h>
359
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_native.h>
360
+ #include <ATen/ops/_to_copy_native.h>
361
+ #include <ATen/ops/_to_cpu_native.h>
362
+ #include <ATen/ops/_to_dense_native.h>
363
+ #include <ATen/ops/_to_sparse_native.h>
364
+ #include <ATen/ops/_to_sparse_bsc_native.h>
365
+ #include <ATen/ops/_to_sparse_bsr_native.h>
366
+ #include <ATen/ops/_to_sparse_csc_native.h>
367
+ #include <ATen/ops/_to_sparse_csr_native.h>
368
+ #include <ATen/ops/_to_sparse_semi_structured_native.h>
369
+ #include <ATen/ops/_transform_bias_rescale_qkv_native.h>
370
+ #include <ATen/ops/_transformer_encoder_layer_fwd_native.h>
371
+ #include <ATen/ops/_trilinear_native.h>
372
+ #include <ATen/ops/_triton_multi_head_attention_native.h>
373
+ #include <ATen/ops/_triton_scaled_dot_attention_native.h>
374
+ #include <ATen/ops/_unique_native.h>
375
+ #include <ATen/ops/_unique2_native.h>
376
+ #include <ATen/ops/_unpack_dual_native.h>
377
+ #include <ATen/ops/_unsafe_index_native.h>
378
+ #include <ATen/ops/_unsafe_index_put_native.h>
379
+ #include <ATen/ops/_unsafe_masked_index_native.h>
380
+ #include <ATen/ops/_unsafe_masked_index_put_accumulate_native.h>
381
+ #include <ATen/ops/_unsafe_view_native.h>
382
+ #include <ATen/ops/_upsample_bicubic2d_aa_native.h>
383
+ #include <ATen/ops/_upsample_bicubic2d_aa_backward_native.h>
384
+ #include <ATen/ops/_upsample_bilinear2d_aa_native.h>
385
+ #include <ATen/ops/_upsample_bilinear2d_aa_backward_native.h>
386
+ #include <ATen/ops/_upsample_nearest_exact1d_native.h>
387
+ #include <ATen/ops/_upsample_nearest_exact1d_backward_native.h>
388
+ #include <ATen/ops/_upsample_nearest_exact2d_native.h>
389
+ #include <ATen/ops/_upsample_nearest_exact2d_backward_native.h>
390
+ #include <ATen/ops/_upsample_nearest_exact3d_native.h>
391
+ #include <ATen/ops/_upsample_nearest_exact3d_backward_native.h>
392
+ #include <ATen/ops/_use_cudnn_ctc_loss_native.h>
393
+ #include <ATen/ops/_use_cudnn_rnn_flatten_weight_native.h>
394
+ #include <ATen/ops/_validate_compressed_sparse_indices_native.h>
395
+ #include <ATen/ops/_validate_sparse_bsc_tensor_args_native.h>
396
+ #include <ATen/ops/_validate_sparse_bsr_tensor_args_native.h>
397
+ #include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
398
+ #include <ATen/ops/_validate_sparse_coo_tensor_args_native.h>
399
+ #include <ATen/ops/_validate_sparse_csc_tensor_args_native.h>
400
+ #include <ATen/ops/_validate_sparse_csr_tensor_args_native.h>
401
+ #include <ATen/ops/_values_native.h>
402
+ #include <ATen/ops/_values_copy_native.h>
403
+ #include <ATen/ops/_version_native.h>
404
+ #include <ATen/ops/_weight_int4pack_mm_native.h>
405
+ #include <ATen/ops/_weight_int8pack_mm_native.h>
406
+ #include <ATen/ops/_weight_norm_native.h>
407
+ #include <ATen/ops/_weight_norm_differentiable_backward_native.h>
408
+ #include <ATen/ops/_weight_norm_interface_native.h>
409
+ #include <ATen/ops/_weight_norm_interface_backward_native.h>
410
+ #include <ATen/ops/_wrapped_linear_prepack_native.h>
411
+ #include <ATen/ops/_wrapped_quantized_linear_prepacked_native.h>
412
+ #include <ATen/ops/abs_native.h>
413
+ #include <ATen/ops/absolute_native.h>
414
+ #include <ATen/ops/acos_native.h>
415
+ #include <ATen/ops/acosh_native.h>
416
+ #include <ATen/ops/adaptive_avg_pool1d_native.h>
417
+ #include <ATen/ops/adaptive_avg_pool2d_native.h>
418
+ #include <ATen/ops/adaptive_avg_pool3d_native.h>
419
+ #include <ATen/ops/adaptive_avg_pool3d_backward_native.h>
420
+ #include <ATen/ops/adaptive_max_pool1d_native.h>
421
+ #include <ATen/ops/adaptive_max_pool2d_native.h>
422
+ #include <ATen/ops/adaptive_max_pool2d_backward_native.h>
423
+ #include <ATen/ops/adaptive_max_pool3d_native.h>
424
+ #include <ATen/ops/adaptive_max_pool3d_backward_native.h>
425
+ #include <ATen/ops/add_native.h>
426
+ #include <ATen/ops/addbmm_native.h>
427
+ #include <ATen/ops/addcdiv_native.h>
428
+ #include <ATen/ops/addcmul_native.h>
429
+ #include <ATen/ops/addmm_native.h>
430
+ #include <ATen/ops/addmv_native.h>
431
+ #include <ATen/ops/addr_native.h>
432
+ #include <ATen/ops/adjoint_native.h>
433
+ #include <ATen/ops/affine_grid_generator_native.h>
434
+ #include <ATen/ops/affine_grid_generator_backward_native.h>
435
+ #include <ATen/ops/alias_native.h>
436
+ #include <ATen/ops/alias_copy_native.h>
437
+ #include <ATen/ops/align_as_native.h>
438
+ #include <ATen/ops/align_tensors_native.h>
439
+ #include <ATen/ops/align_to_native.h>
440
+ #include <ATen/ops/all_native.h>
441
+ #include <ATen/ops/allclose_native.h>
442
+ #include <ATen/ops/alpha_dropout_native.h>
443
+ #include <ATen/ops/amax_native.h>
444
+ #include <ATen/ops/amin_native.h>
445
+ #include <ATen/ops/aminmax_native.h>
446
+ #include <ATen/ops/and_native.h>
447
+ #include <ATen/ops/angle_native.h>
448
+ #include <ATen/ops/any_native.h>
449
+ #include <ATen/ops/arange_native.h>
450
+ #include <ATen/ops/arccos_native.h>
451
+ #include <ATen/ops/arccosh_native.h>
452
+ #include <ATen/ops/arcsin_native.h>
453
+ #include <ATen/ops/arcsinh_native.h>
454
+ #include <ATen/ops/arctan_native.h>
455
+ #include <ATen/ops/arctan2_native.h>
456
+ #include <ATen/ops/arctanh_native.h>
457
+ #include <ATen/ops/argmax_native.h>
458
+ #include <ATen/ops/argmin_native.h>
459
+ #include <ATen/ops/argsort_native.h>
460
+ #include <ATen/ops/argwhere_native.h>
461
+ #include <ATen/ops/as_strided_native.h>
462
+ #include <ATen/ops/as_strided_copy_native.h>
463
+ #include <ATen/ops/as_strided_scatter_native.h>
464
+ #include <ATen/ops/asin_native.h>
465
+ #include <ATen/ops/asinh_native.h>
466
+ #include <ATen/ops/atan_native.h>
467
+ #include <ATen/ops/atan2_native.h>
468
+ #include <ATen/ops/atanh_native.h>
469
+ #include <ATen/ops/atleast_1d_native.h>
470
+ #include <ATen/ops/atleast_2d_native.h>
471
+ #include <ATen/ops/atleast_3d_native.h>
472
+ #include <ATen/ops/avg_pool1d_native.h>
473
+ #include <ATen/ops/avg_pool2d_native.h>
474
+ #include <ATen/ops/avg_pool2d_backward_native.h>
475
+ #include <ATen/ops/avg_pool3d_native.h>
476
+ #include <ATen/ops/avg_pool3d_backward_native.h>
477
+ #include <ATen/ops/baddbmm_native.h>
478
+ #include <ATen/ops/bartlett_window_native.h>
479
+ #include <ATen/ops/batch_norm_native.h>
480
+ #include <ATen/ops/batch_norm_backward_native.h>
481
+ #include <ATen/ops/batch_norm_backward_elemt_native.h>
482
+ #include <ATen/ops/batch_norm_backward_reduce_native.h>
483
+ #include <ATen/ops/batch_norm_elemt_native.h>
484
+ #include <ATen/ops/batch_norm_gather_stats_native.h>
485
+ #include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
486
+ #include <ATen/ops/batch_norm_stats_native.h>
487
+ #include <ATen/ops/batch_norm_update_stats_native.h>
488
+ #include <ATen/ops/bernoulli_native.h>
489
+ #include <ATen/ops/bilinear_native.h>
490
+ #include <ATen/ops/binary_cross_entropy_native.h>
491
+ #include <ATen/ops/binary_cross_entropy_backward_native.h>
492
+ #include <ATen/ops/binary_cross_entropy_with_logits_native.h>
493
+ #include <ATen/ops/bincount_native.h>
494
+ #include <ATen/ops/binomial_native.h>
495
+ #include <ATen/ops/bitwise_and_native.h>
496
+ #include <ATen/ops/bitwise_left_shift_native.h>
497
+ #include <ATen/ops/bitwise_not_native.h>
498
+ #include <ATen/ops/bitwise_or_native.h>
499
+ #include <ATen/ops/bitwise_right_shift_native.h>
500
+ #include <ATen/ops/bitwise_xor_native.h>
501
+ #include <ATen/ops/blackman_window_native.h>
502
+ #include <ATen/ops/block_diag_native.h>
503
+ #include <ATen/ops/bmm_native.h>
504
+ #include <ATen/ops/broadcast_tensors_native.h>
505
+ #include <ATen/ops/broadcast_to_native.h>
506
+ #include <ATen/ops/bucketize_native.h>
507
+ #include <ATen/ops/can_cast_native.h>
508
+ #include <ATen/ops/cartesian_prod_native.h>
509
+ #include <ATen/ops/cat_native.h>
510
+ #include <ATen/ops/cauchy_native.h>
511
+ #include <ATen/ops/ccol_indices_native.h>
512
+ #include <ATen/ops/ccol_indices_copy_native.h>
513
+ #include <ATen/ops/cdist_native.h>
514
+ #include <ATen/ops/ceil_native.h>
515
+ #include <ATen/ops/celu_native.h>
516
+ #include <ATen/ops/chain_matmul_native.h>
517
+ #include <ATen/ops/chalf_native.h>
518
+ #include <ATen/ops/channel_shuffle_native.h>
519
+ #include <ATen/ops/cholesky_native.h>
520
+ #include <ATen/ops/cholesky_inverse_native.h>
521
+ #include <ATen/ops/cholesky_solve_native.h>
522
+ #include <ATen/ops/choose_qparams_optimized_native.h>
523
+ #include <ATen/ops/chunk_native.h>
524
+ #include <ATen/ops/clamp_native.h>
525
+ #include <ATen/ops/clamp_max_native.h>
526
+ #include <ATen/ops/clamp_min_native.h>
527
+ #include <ATen/ops/clip_native.h>
528
+ #include <ATen/ops/clone_native.h>
529
+ #include <ATen/ops/coalesce_native.h>
530
+ #include <ATen/ops/col2im_native.h>
531
+ #include <ATen/ops/col_indices_native.h>
532
+ #include <ATen/ops/col_indices_copy_native.h>
533
+ #include <ATen/ops/column_stack_native.h>
534
+ #include <ATen/ops/combinations_native.h>
535
+ #include <ATen/ops/complex_native.h>
536
+ #include <ATen/ops/concat_native.h>
537
+ #include <ATen/ops/concatenate_native.h>
538
+ #include <ATen/ops/conj_native.h>
539
+ #include <ATen/ops/conj_physical_native.h>
540
+ #include <ATen/ops/constant_pad_nd_native.h>
541
+ #include <ATen/ops/contiguous_native.h>
542
+ #include <ATen/ops/conv1d_native.h>
543
+ #include <ATen/ops/conv2d_native.h>
544
+ #include <ATen/ops/conv3d_native.h>
545
+ #include <ATen/ops/conv_depthwise3d_native.h>
546
+ #include <ATen/ops/conv_tbc_native.h>
547
+ #include <ATen/ops/conv_tbc_backward_native.h>
548
+ #include <ATen/ops/conv_transpose1d_native.h>
549
+ #include <ATen/ops/conv_transpose2d_native.h>
550
+ #include <ATen/ops/conv_transpose3d_native.h>
551
+ #include <ATen/ops/convolution_native.h>
552
+ #include <ATen/ops/convolution_backward_native.h>
553
+ #include <ATen/ops/convolution_backward_overrideable_native.h>
554
+ #include <ATen/ops/convolution_overrideable_native.h>
555
+ #include <ATen/ops/copy_native.h>
556
+ #include <ATen/ops/copy_sparse_to_sparse_native.h>
557
+ #include <ATen/ops/copysign_native.h>
558
+ #include <ATen/ops/corrcoef_native.h>
559
+ #include <ATen/ops/cos_native.h>
560
+ #include <ATen/ops/cosh_native.h>
561
+ #include <ATen/ops/cosine_embedding_loss_native.h>
562
+ #include <ATen/ops/cosine_similarity_native.h>
563
+ #include <ATen/ops/count_nonzero_native.h>
564
+ #include <ATen/ops/cov_native.h>
565
+ #include <ATen/ops/cross_native.h>
566
+ #include <ATen/ops/cross_entropy_loss_native.h>
567
+ #include <ATen/ops/crow_indices_native.h>
568
+ #include <ATen/ops/crow_indices_copy_native.h>
569
+ #include <ATen/ops/ctc_loss_native.h>
570
+ #include <ATen/ops/cudnn_affine_grid_generator_native.h>
571
+ #include <ATen/ops/cudnn_affine_grid_generator_backward_native.h>
572
+ #include <ATen/ops/cudnn_batch_norm_native.h>
573
+ #include <ATen/ops/cudnn_batch_norm_backward_native.h>
574
+ #include <ATen/ops/cudnn_convolution_native.h>
575
+ #include <ATen/ops/cudnn_convolution_add_relu_native.h>
576
+ #include <ATen/ops/cudnn_convolution_relu_native.h>
577
+ #include <ATen/ops/cudnn_convolution_transpose_native.h>
578
+ #include <ATen/ops/cudnn_grid_sampler_native.h>
579
+ #include <ATen/ops/cudnn_grid_sampler_backward_native.h>
580
+ #include <ATen/ops/cudnn_is_acceptable_native.h>
581
+ #include <ATen/ops/cummax_native.h>
582
+ #include <ATen/ops/cummaxmin_backward_native.h>
583
+ #include <ATen/ops/cummin_native.h>
584
+ #include <ATen/ops/cumprod_native.h>
585
+ #include <ATen/ops/cumprod_backward_native.h>
586
+ #include <ATen/ops/cumsum_native.h>
587
+ #include <ATen/ops/cumulative_trapezoid_native.h>
588
+ #include <ATen/ops/data_native.h>
589
+ #include <ATen/ops/deg2rad_native.h>
590
+ #include <ATen/ops/dense_dim_native.h>
591
+ #include <ATen/ops/dequantize_native.h>
592
+ #include <ATen/ops/det_native.h>
593
+ #include <ATen/ops/detach_native.h>
594
+ #include <ATen/ops/detach_copy_native.h>
595
+ #include <ATen/ops/diag_native.h>
596
+ #include <ATen/ops/diag_embed_native.h>
597
+ #include <ATen/ops/diagflat_native.h>
598
+ #include <ATen/ops/diagonal_native.h>
599
+ #include <ATen/ops/diagonal_backward_native.h>
600
+ #include <ATen/ops/diagonal_copy_native.h>
601
+ #include <ATen/ops/diagonal_scatter_native.h>
602
+ #include <ATen/ops/diff_native.h>
603
+ #include <ATen/ops/digamma_native.h>
604
+ #include <ATen/ops/dist_native.h>
605
+ #include <ATen/ops/div_native.h>
606
+ #include <ATen/ops/divide_native.h>
607
+ #include <ATen/ops/dot_native.h>
608
+ #include <ATen/ops/dropout_native.h>
609
+ #include <ATen/ops/dsplit_native.h>
610
+ #include <ATen/ops/dstack_native.h>
611
+ #include <ATen/ops/einsum_native.h>
612
+ #include <ATen/ops/elu_native.h>
613
+ #include <ATen/ops/elu_backward_native.h>
614
+ #include <ATen/ops/embedding_native.h>
615
+ #include <ATen/ops/embedding_backward_native.h>
616
+ #include <ATen/ops/embedding_bag_native.h>
617
+ #include <ATen/ops/embedding_dense_backward_native.h>
618
+ #include <ATen/ops/embedding_renorm_native.h>
619
+ #include <ATen/ops/embedding_sparse_backward_native.h>
620
+ #include <ATen/ops/empty_native.h>
621
+ #include <ATen/ops/empty_like_native.h>
622
+ #include <ATen/ops/empty_permuted_native.h>
623
+ #include <ATen/ops/empty_quantized_native.h>
624
+ #include <ATen/ops/empty_strided_native.h>
625
+ #include <ATen/ops/eq_native.h>
626
+ #include <ATen/ops/equal_native.h>
627
+ #include <ATen/ops/erf_native.h>
628
+ #include <ATen/ops/erfc_native.h>
629
+ #include <ATen/ops/erfinv_native.h>
630
+ #include <ATen/ops/exp_native.h>
631
+ #include <ATen/ops/exp2_native.h>
632
+ #include <ATen/ops/expand_native.h>
633
+ #include <ATen/ops/expand_as_native.h>
634
+ #include <ATen/ops/expand_copy_native.h>
635
+ #include <ATen/ops/expm1_native.h>
636
+ #include <ATen/ops/exponential_native.h>
637
+ #include <ATen/ops/eye_native.h>
638
+ #include <ATen/ops/fake_quantize_per_channel_affine_native.h>
639
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_native.h>
640
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_native.h>
641
+ #include <ATen/ops/fake_quantize_per_tensor_affine_native.h>
642
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_native.h>
643
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_native.h>
644
+ #include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
645
+ #include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
646
+ #include <ATen/ops/fbgemm_linear_int8_weight_native.h>
647
+ #include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_native.h>
648
+ #include <ATen/ops/fbgemm_linear_quantize_weight_native.h>
649
+ #include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
650
+ #include <ATen/ops/fbgemm_pack_quantized_matrix_native.h>
651
+ #include <ATen/ops/feature_alpha_dropout_native.h>
652
+ #include <ATen/ops/feature_dropout_native.h>
653
+ #include <ATen/ops/fft_fft_native.h>
654
+ #include <ATen/ops/fft_fft2_native.h>
655
+ #include <ATen/ops/fft_fftfreq_native.h>
656
+ #include <ATen/ops/fft_fftn_native.h>
657
+ #include <ATen/ops/fft_fftshift_native.h>
658
+ #include <ATen/ops/fft_hfft_native.h>
659
+ #include <ATen/ops/fft_hfft2_native.h>
660
+ #include <ATen/ops/fft_hfftn_native.h>
661
+ #include <ATen/ops/fft_ifft_native.h>
662
+ #include <ATen/ops/fft_ifft2_native.h>
663
+ #include <ATen/ops/fft_ifftn_native.h>
664
+ #include <ATen/ops/fft_ifftshift_native.h>
665
+ #include <ATen/ops/fft_ihfft_native.h>
666
+ #include <ATen/ops/fft_ihfft2_native.h>
667
+ #include <ATen/ops/fft_ihfftn_native.h>
668
+ #include <ATen/ops/fft_irfft_native.h>
669
+ #include <ATen/ops/fft_irfft2_native.h>
670
+ #include <ATen/ops/fft_irfftn_native.h>
671
+ #include <ATen/ops/fft_rfft_native.h>
672
+ #include <ATen/ops/fft_rfft2_native.h>
673
+ #include <ATen/ops/fft_rfftfreq_native.h>
674
+ #include <ATen/ops/fft_rfftn_native.h>
675
+ #include <ATen/ops/fill_native.h>
676
+ #include <ATen/ops/fill_diagonal_native.h>
677
+ #include <ATen/ops/fix_native.h>
678
+ #include <ATen/ops/flatten_native.h>
679
+ #include <ATen/ops/flatten_dense_tensors_native.h>
680
+ #include <ATen/ops/flip_native.h>
681
+ #include <ATen/ops/fliplr_native.h>
682
+ #include <ATen/ops/flipud_native.h>
683
+ #include <ATen/ops/float_power_native.h>
684
+ #include <ATen/ops/floor_native.h>
685
+ #include <ATen/ops/floor_divide_native.h>
686
+ #include <ATen/ops/fmax_native.h>
687
+ #include <ATen/ops/fmin_native.h>
688
+ #include <ATen/ops/fmod_native.h>
689
+ #include <ATen/ops/frac_native.h>
690
+ #include <ATen/ops/fractional_max_pool2d_native.h>
691
+ #include <ATen/ops/fractional_max_pool2d_backward_native.h>
692
+ #include <ATen/ops/fractional_max_pool3d_native.h>
693
+ #include <ATen/ops/fractional_max_pool3d_backward_native.h>
694
+ #include <ATen/ops/frexp_native.h>
695
+ #include <ATen/ops/frobenius_norm_native.h>
696
+ #include <ATen/ops/from_file_native.h>
697
+ #include <ATen/ops/full_native.h>
698
+ #include <ATen/ops/full_like_native.h>
699
+ #include <ATen/ops/fused_moving_avg_obs_fake_quant_native.h>
700
+ #include <ATen/ops/gather_native.h>
701
+ #include <ATen/ops/gather_backward_native.h>
702
+ #include <ATen/ops/gcd_native.h>
703
+ #include <ATen/ops/ge_native.h>
704
+ #include <ATen/ops/gelu_native.h>
705
+ #include <ATen/ops/gelu_backward_native.h>
706
+ #include <ATen/ops/geometric_native.h>
707
+ #include <ATen/ops/geqrf_native.h>
708
+ #include <ATen/ops/ger_native.h>
709
+ #include <ATen/ops/glu_native.h>
710
+ #include <ATen/ops/glu_backward_native.h>
711
+ #include <ATen/ops/glu_backward_jvp_native.h>
712
+ #include <ATen/ops/glu_jvp_native.h>
713
+ #include <ATen/ops/gradient_native.h>
714
+ #include <ATen/ops/greater_native.h>
715
+ #include <ATen/ops/greater_equal_native.h>
716
+ #include <ATen/ops/grid_sampler_native.h>
717
+ #include <ATen/ops/grid_sampler_2d_native.h>
718
+ #include <ATen/ops/grid_sampler_2d_backward_native.h>
719
+ #include <ATen/ops/grid_sampler_3d_native.h>
720
+ #include <ATen/ops/grid_sampler_3d_backward_native.h>
721
+ #include <ATen/ops/group_norm_native.h>
722
+ #include <ATen/ops/gru_native.h>
723
+ #include <ATen/ops/gru_cell_native.h>
724
+ #include <ATen/ops/gt_native.h>
725
+ #include <ATen/ops/hamming_window_native.h>
726
+ #include <ATen/ops/hann_window_native.h>
727
+ #include <ATen/ops/hardshrink_native.h>
728
+ #include <ATen/ops/hardshrink_backward_native.h>
729
+ #include <ATen/ops/hardsigmoid_native.h>
730
+ #include <ATen/ops/hardsigmoid_backward_native.h>
731
+ #include <ATen/ops/hardswish_native.h>
732
+ #include <ATen/ops/hardswish_backward_native.h>
733
+ #include <ATen/ops/hardtanh_native.h>
734
+ #include <ATen/ops/hardtanh_backward_native.h>
735
+ #include <ATen/ops/heaviside_native.h>
736
+ #include <ATen/ops/hinge_embedding_loss_native.h>
737
+ #include <ATen/ops/histc_native.h>
738
+ #include <ATen/ops/histogram_native.h>
739
+ #include <ATen/ops/histogramdd_native.h>
740
+ #include <ATen/ops/hsplit_native.h>
741
+ #include <ATen/ops/hspmm_native.h>
742
+ #include <ATen/ops/hstack_native.h>
743
+ #include <ATen/ops/huber_loss_native.h>
744
+ #include <ATen/ops/huber_loss_backward_native.h>
745
+ #include <ATen/ops/hypot_native.h>
746
+ #include <ATen/ops/i0_native.h>
747
+ #include <ATen/ops/igamma_native.h>
748
+ #include <ATen/ops/igammac_native.h>
749
+ #include <ATen/ops/im2col_native.h>
750
+ #include <ATen/ops/imag_native.h>
751
+ #include <ATen/ops/index_native.h>
752
+ #include <ATen/ops/index_add_native.h>
753
+ #include <ATen/ops/index_copy_native.h>
754
+ #include <ATen/ops/index_fill_native.h>
755
+ #include <ATen/ops/index_put_native.h>
756
+ #include <ATen/ops/index_reduce_native.h>
757
+ #include <ATen/ops/index_select_native.h>
758
+ #include <ATen/ops/index_select_backward_native.h>
759
+ #include <ATen/ops/indices_native.h>
760
+ #include <ATen/ops/indices_copy_native.h>
761
+ #include <ATen/ops/infinitely_differentiable_gelu_backward_native.h>
762
+ #include <ATen/ops/inner_native.h>
763
+ #include <ATen/ops/instance_norm_native.h>
764
+ #include <ATen/ops/int_repr_native.h>
765
+ #include <ATen/ops/inverse_native.h>
766
+ #include <ATen/ops/is_coalesced_native.h>
767
+ #include <ATen/ops/is_complex_native.h>
768
+ #include <ATen/ops/is_conj_native.h>
769
+ #include <ATen/ops/is_distributed_native.h>
770
+ #include <ATen/ops/is_floating_point_native.h>
771
+ #include <ATen/ops/is_inference_native.h>
772
+ #include <ATen/ops/is_leaf_native.h>
773
+ #include <ATen/ops/is_neg_native.h>
774
+ #include <ATen/ops/is_nonzero_native.h>
775
+ #include <ATen/ops/is_pinned_native.h>
776
+ #include <ATen/ops/is_same_size_native.h>
777
+ #include <ATen/ops/is_set_to_native.h>
778
+ #include <ATen/ops/is_signed_native.h>
779
+ #include <ATen/ops/is_vulkan_available_native.h>
780
+ #include <ATen/ops/isclose_native.h>
781
+ #include <ATen/ops/isfinite_native.h>
782
+ #include <ATen/ops/isin_native.h>
783
+ #include <ATen/ops/isinf_native.h>
784
+ #include <ATen/ops/isnan_native.h>
785
+ #include <ATen/ops/isneginf_native.h>
786
+ #include <ATen/ops/isposinf_native.h>
787
+ #include <ATen/ops/isreal_native.h>
788
+ #include <ATen/ops/istft_native.h>
789
+ #include <ATen/ops/item_native.h>
790
+ #include <ATen/ops/kaiser_window_native.h>
791
+ #include <ATen/ops/kl_div_native.h>
792
+ #include <ATen/ops/kron_native.h>
793
+ #include <ATen/ops/kthvalue_native.h>
794
+ #include <ATen/ops/l1_loss_native.h>
795
+ #include <ATen/ops/layer_norm_native.h>
796
+ #include <ATen/ops/lcm_native.h>
797
+ #include <ATen/ops/ldexp_native.h>
798
+ #include <ATen/ops/le_native.h>
799
+ #include <ATen/ops/leaky_relu_native.h>
800
+ #include <ATen/ops/leaky_relu_backward_native.h>
801
+ #include <ATen/ops/lerp_native.h>
802
+ #include <ATen/ops/less_native.h>
803
+ #include <ATen/ops/less_equal_native.h>
804
+ #include <ATen/ops/lgamma_native.h>
805
+ #include <ATen/ops/lift_native.h>
806
+ #include <ATen/ops/lift_fresh_native.h>
807
+ #include <ATen/ops/lift_fresh_copy_native.h>
808
+ #include <ATen/ops/linalg_cholesky_native.h>
809
+ #include <ATen/ops/linalg_cholesky_ex_native.h>
810
+ #include <ATen/ops/linalg_cond_native.h>
811
+ #include <ATen/ops/linalg_cross_native.h>
812
+ #include <ATen/ops/linalg_det_native.h>
813
+ #include <ATen/ops/linalg_diagonal_native.h>
814
+ #include <ATen/ops/linalg_eig_native.h>
815
+ #include <ATen/ops/linalg_eigh_native.h>
816
+ #include <ATen/ops/linalg_eigvals_native.h>
817
+ #include <ATen/ops/linalg_eigvalsh_native.h>
818
+ #include <ATen/ops/linalg_householder_product_native.h>
819
+ #include <ATen/ops/linalg_inv_native.h>
820
+ #include <ATen/ops/linalg_inv_ex_native.h>
821
+ #include <ATen/ops/linalg_ldl_factor_native.h>
822
+ #include <ATen/ops/linalg_ldl_factor_ex_native.h>
823
+ #include <ATen/ops/linalg_ldl_solve_native.h>
824
+ #include <ATen/ops/linalg_lstsq_native.h>
825
+ #include <ATen/ops/linalg_lu_native.h>
826
+ #include <ATen/ops/linalg_lu_factor_native.h>
827
+ #include <ATen/ops/linalg_lu_factor_ex_native.h>
828
+ #include <ATen/ops/linalg_lu_solve_native.h>
829
+ #include <ATen/ops/linalg_matmul_native.h>
830
+ #include <ATen/ops/linalg_matrix_exp_native.h>
831
+ #include <ATen/ops/linalg_matrix_norm_native.h>
832
+ #include <ATen/ops/linalg_matrix_power_native.h>
833
+ #include <ATen/ops/linalg_matrix_rank_native.h>
834
+ #include <ATen/ops/linalg_multi_dot_native.h>
835
+ #include <ATen/ops/linalg_norm_native.h>
836
+ #include <ATen/ops/linalg_pinv_native.h>
837
+ #include <ATen/ops/linalg_qr_native.h>
838
+ #include <ATen/ops/linalg_slogdet_native.h>
839
+ #include <ATen/ops/linalg_solve_native.h>
840
+ #include <ATen/ops/linalg_solve_ex_native.h>
841
+ #include <ATen/ops/linalg_solve_triangular_native.h>
842
+ #include <ATen/ops/linalg_svd_native.h>
843
+ #include <ATen/ops/linalg_svdvals_native.h>
844
+ #include <ATen/ops/linalg_tensorinv_native.h>
845
+ #include <ATen/ops/linalg_tensorsolve_native.h>
846
+ #include <ATen/ops/linalg_vander_native.h>
847
+ #include <ATen/ops/linalg_vecdot_native.h>
848
+ #include <ATen/ops/linalg_vector_norm_native.h>
849
+ #include <ATen/ops/linear_native.h>
850
+ #include <ATen/ops/linear_backward_native.h>
851
+ #include <ATen/ops/linspace_native.h>
852
+ #include <ATen/ops/log_native.h>
853
+ #include <ATen/ops/log10_native.h>
854
+ #include <ATen/ops/log1p_native.h>
855
+ #include <ATen/ops/log2_native.h>
856
+ #include <ATen/ops/log_normal_native.h>
857
+ #include <ATen/ops/log_sigmoid_native.h>
858
+ #include <ATen/ops/log_sigmoid_backward_native.h>
859
+ #include <ATen/ops/log_sigmoid_forward_native.h>
860
+ #include <ATen/ops/log_softmax_native.h>
861
+ #include <ATen/ops/logaddexp_native.h>
862
+ #include <ATen/ops/logaddexp2_native.h>
863
+ #include <ATen/ops/logcumsumexp_native.h>
864
+ #include <ATen/ops/logdet_native.h>
865
+ #include <ATen/ops/logical_and_native.h>
866
+ #include <ATen/ops/logical_not_native.h>
867
+ #include <ATen/ops/logical_or_native.h>
868
+ #include <ATen/ops/logical_xor_native.h>
869
+ #include <ATen/ops/logit_native.h>
870
+ #include <ATen/ops/logit_backward_native.h>
871
+ #include <ATen/ops/logspace_native.h>
872
+ #include <ATen/ops/logsumexp_native.h>
873
+ #include <ATen/ops/lshift_native.h>
874
+ #include <ATen/ops/lstm_native.h>
875
+ #include <ATen/ops/lstm_cell_native.h>
876
+ #include <ATen/ops/lstm_mps_backward_native.h>
877
+ #include <ATen/ops/lt_native.h>
878
+ #include <ATen/ops/lu_solve_native.h>
879
+ #include <ATen/ops/lu_unpack_native.h>
880
+ #include <ATen/ops/mH_native.h>
881
+ #include <ATen/ops/mT_native.h>
882
+ #include <ATen/ops/margin_ranking_loss_native.h>
883
+ #include <ATen/ops/masked_fill_native.h>
884
+ #include <ATen/ops/masked_scatter_native.h>
885
+ #include <ATen/ops/masked_scatter_backward_native.h>
886
+ #include <ATen/ops/masked_select_native.h>
887
+ #include <ATen/ops/masked_select_backward_native.h>
888
+ #include <ATen/ops/matmul_native.h>
889
+ #include <ATen/ops/matmul_backward_native.h>
890
+ #include <ATen/ops/matrix_H_native.h>
891
+ #include <ATen/ops/matrix_exp_native.h>
892
+ #include <ATen/ops/matrix_exp_backward_native.h>
893
+ #include <ATen/ops/matrix_power_native.h>
894
+ #include <ATen/ops/max_native.h>
895
+ #include <ATen/ops/max_pool1d_native.h>
896
+ #include <ATen/ops/max_pool1d_with_indices_native.h>
897
+ #include <ATen/ops/max_pool2d_native.h>
898
+ #include <ATen/ops/max_pool2d_backward_native.h>
899
+ #include <ATen/ops/max_pool2d_with_indices_native.h>
900
+ #include <ATen/ops/max_pool2d_with_indices_backward_native.h>
901
+ #include <ATen/ops/max_pool3d_native.h>
902
+ #include <ATen/ops/max_pool3d_with_indices_native.h>
903
+ #include <ATen/ops/max_pool3d_with_indices_backward_native.h>
904
+ #include <ATen/ops/max_unpool2d_native.h>
905
+ #include <ATen/ops/max_unpool3d_native.h>
906
+ #include <ATen/ops/maximum_native.h>
907
+ #include <ATen/ops/mean_native.h>
908
+ #include <ATen/ops/median_native.h>
909
+ #include <ATen/ops/meshgrid_native.h>
910
+ #include <ATen/ops/min_native.h>
911
+ #include <ATen/ops/minimum_native.h>
912
+ #include <ATen/ops/miopen_batch_norm_native.h>
913
+ #include <ATen/ops/miopen_batch_norm_backward_native.h>
914
+ #include <ATen/ops/miopen_convolution_native.h>
915
+ #include <ATen/ops/miopen_convolution_add_relu_native.h>
916
+ #include <ATen/ops/miopen_convolution_relu_native.h>
917
+ #include <ATen/ops/miopen_convolution_transpose_native.h>
918
+ #include <ATen/ops/miopen_depthwise_convolution_native.h>
919
+ #include <ATen/ops/miopen_rnn_native.h>
920
+ #include <ATen/ops/miopen_rnn_backward_native.h>
921
+ #include <ATen/ops/mish_native.h>
922
+ #include <ATen/ops/mish_backward_native.h>
923
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d_native.h>
924
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_native.h>
925
+ #include <ATen/ops/mkldnn_convolution_native.h>
926
+ #include <ATen/ops/mkldnn_linear_native.h>
927
+ #include <ATen/ops/mkldnn_linear_backward_native.h>
928
+ #include <ATen/ops/mkldnn_linear_backward_input_native.h>
929
+ #include <ATen/ops/mkldnn_linear_backward_weights_native.h>
930
+ #include <ATen/ops/mkldnn_max_pool2d_native.h>
931
+ #include <ATen/ops/mkldnn_max_pool2d_backward_native.h>
932
+ #include <ATen/ops/mkldnn_max_pool3d_native.h>
933
+ #include <ATen/ops/mkldnn_max_pool3d_backward_native.h>
934
+ #include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
935
+ #include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
936
+ #include <ATen/ops/mkldnn_rnn_layer_native.h>
937
+ #include <ATen/ops/mkldnn_rnn_layer_backward_native.h>
938
+ #include <ATen/ops/mm_native.h>
939
+ #include <ATen/ops/mode_native.h>
940
+ #include <ATen/ops/moveaxis_native.h>
941
+ #include <ATen/ops/movedim_native.h>
942
+ #include <ATen/ops/mps_convolution_backward_native.h>
943
+ #include <ATen/ops/mps_convolution_transpose_backward_native.h>
944
+ #include <ATen/ops/mse_loss_native.h>
945
+ #include <ATen/ops/mse_loss_backward_native.h>
946
+ #include <ATen/ops/msort_native.h>
947
+ #include <ATen/ops/mul_native.h>
948
+ #include <ATen/ops/multi_margin_loss_native.h>
949
+ #include <ATen/ops/multi_margin_loss_backward_native.h>
950
+ #include <ATen/ops/multilabel_margin_loss_native.h>
951
+ #include <ATen/ops/multilabel_margin_loss_backward_native.h>
952
+ #include <ATen/ops/multilabel_margin_loss_forward_native.h>
953
+ #include <ATen/ops/multinomial_native.h>
954
+ #include <ATen/ops/multiply_native.h>
955
+ #include <ATen/ops/mv_native.h>
956
+ #include <ATen/ops/mvlgamma_native.h>
957
+ #include <ATen/ops/nan_to_num_native.h>
958
+ #include <ATen/ops/nanmean_native.h>
959
+ #include <ATen/ops/nanmedian_native.h>
960
+ #include <ATen/ops/nanquantile_native.h>
961
+ #include <ATen/ops/nansum_native.h>
962
+ #include <ATen/ops/narrow_native.h>
963
+ #include <ATen/ops/narrow_copy_native.h>
964
+ #include <ATen/ops/native_batch_norm_native.h>
965
+ #include <ATen/ops/native_batch_norm_backward_native.h>
966
+ #include <ATen/ops/native_channel_shuffle_native.h>
967
+ #include <ATen/ops/native_dropout_native.h>
968
+ #include <ATen/ops/native_dropout_backward_native.h>
969
+ #include <ATen/ops/native_group_norm_native.h>
970
+ #include <ATen/ops/native_group_norm_backward_native.h>
971
+ #include <ATen/ops/native_layer_norm_native.h>
972
+ #include <ATen/ops/native_layer_norm_backward_native.h>
973
+ #include <ATen/ops/native_norm_native.h>
974
+ #include <ATen/ops/ne_native.h>
975
+ #include <ATen/ops/neg_native.h>
976
+ #include <ATen/ops/negative_native.h>
977
+ #include <ATen/ops/nested_to_padded_tensor_native.h>
978
+ #include <ATen/ops/new_empty_native.h>
979
+ #include <ATen/ops/new_empty_strided_native.h>
980
+ #include <ATen/ops/new_full_native.h>
981
+ #include <ATen/ops/new_ones_native.h>
982
+ #include <ATen/ops/new_zeros_native.h>
983
+ #include <ATen/ops/nextafter_native.h>
984
+ #include <ATen/ops/nll_loss_native.h>
985
+ #include <ATen/ops/nll_loss2d_native.h>
986
+ #include <ATen/ops/nll_loss2d_backward_native.h>
987
+ #include <ATen/ops/nll_loss2d_forward_native.h>
988
+ #include <ATen/ops/nll_loss_backward_native.h>
989
+ #include <ATen/ops/nll_loss_forward_native.h>
990
+ #include <ATen/ops/nll_loss_nd_native.h>
991
+ #include <ATen/ops/nonzero_native.h>
992
+ #include <ATen/ops/nonzero_numpy_native.h>
993
+ #include <ATen/ops/nonzero_static_native.h>
994
+ #include <ATen/ops/norm_native.h>
995
+ #include <ATen/ops/norm_except_dim_native.h>
996
+ #include <ATen/ops/normal_native.h>
997
+ #include <ATen/ops/not_equal_native.h>
998
+ #include <ATen/ops/nuclear_norm_native.h>
999
+ #include <ATen/ops/numpy_T_native.h>
1000
+ #include <ATen/ops/one_hot_native.h>
1001
+ #include <ATen/ops/ones_native.h>
1002
+ #include <ATen/ops/ones_like_native.h>
1003
+ #include <ATen/ops/or_native.h>
1004
+ #include <ATen/ops/orgqr_native.h>
1005
+ #include <ATen/ops/ormqr_native.h>
1006
+ #include <ATen/ops/outer_native.h>
1007
+ #include <ATen/ops/output_nr_native.h>
1008
+ #include <ATen/ops/pad_native.h>
1009
+ #include <ATen/ops/pad_sequence_native.h>
1010
+ #include <ATen/ops/pairwise_distance_native.h>
1011
+ #include <ATen/ops/pdist_native.h>
1012
+ #include <ATen/ops/permute_native.h>
1013
+ #include <ATen/ops/permute_copy_native.h>
1014
+ #include <ATen/ops/pin_memory_native.h>
1015
+ #include <ATen/ops/pinverse_native.h>
1016
+ #include <ATen/ops/pixel_shuffle_native.h>
1017
+ #include <ATen/ops/pixel_unshuffle_native.h>
1018
+ #include <ATen/ops/poisson_native.h>
1019
+ #include <ATen/ops/poisson_nll_loss_native.h>
1020
+ #include <ATen/ops/polar_native.h>
1021
+ #include <ATen/ops/polygamma_native.h>
1022
+ #include <ATen/ops/positive_native.h>
1023
+ #include <ATen/ops/pow_native.h>
1024
+ #include <ATen/ops/prelu_native.h>
1025
+ #include <ATen/ops/prod_native.h>
1026
+ #include <ATen/ops/promote_types_native.h>
1027
+ #include <ATen/ops/put_native.h>
1028
+ #include <ATen/ops/q_per_channel_axis_native.h>
1029
+ #include <ATen/ops/q_per_channel_scales_native.h>
1030
+ #include <ATen/ops/q_per_channel_zero_points_native.h>
1031
+ #include <ATen/ops/q_scale_native.h>
1032
+ #include <ATen/ops/q_zero_point_native.h>
1033
+ #include <ATen/ops/qr_native.h>
1034
+ #include <ATen/ops/qscheme_native.h>
1035
+ #include <ATen/ops/quantile_native.h>
1036
+ #include <ATen/ops/quantize_per_channel_native.h>
1037
+ #include <ATen/ops/quantize_per_tensor_native.h>
1038
+ #include <ATen/ops/quantize_per_tensor_dynamic_native.h>
1039
+ #include <ATen/ops/quantized_batch_norm_native.h>
1040
+ #include <ATen/ops/quantized_gru_cell_native.h>
1041
+ #include <ATen/ops/quantized_lstm_cell_native.h>
1042
+ #include <ATen/ops/quantized_max_pool1d_native.h>
1043
+ #include <ATen/ops/quantized_max_pool2d_native.h>
1044
+ #include <ATen/ops/quantized_max_pool3d_native.h>
1045
+ #include <ATen/ops/quantized_rnn_relu_cell_native.h>
1046
+ #include <ATen/ops/quantized_rnn_tanh_cell_native.h>
1047
+ #include <ATen/ops/rad2deg_native.h>
1048
+ #include <ATen/ops/rand_native.h>
1049
+ #include <ATen/ops/rand_like_native.h>
1050
+ #include <ATen/ops/randint_native.h>
1051
+ #include <ATen/ops/randint_like_native.h>
1052
+ #include <ATen/ops/randn_native.h>
1053
+ #include <ATen/ops/randn_like_native.h>
1054
+ #include <ATen/ops/random_native.h>
1055
+ #include <ATen/ops/randperm_native.h>
1056
+ #include <ATen/ops/range_native.h>
1057
+ #include <ATen/ops/ravel_native.h>
1058
+ #include <ATen/ops/real_native.h>
1059
+ #include <ATen/ops/reciprocal_native.h>
1060
+ #include <ATen/ops/record_stream_native.h>
1061
+ #include <ATen/ops/refine_names_native.h>
1062
+ #include <ATen/ops/reflection_pad1d_native.h>
1063
+ #include <ATen/ops/reflection_pad1d_backward_native.h>
1064
+ #include <ATen/ops/reflection_pad2d_native.h>
1065
+ #include <ATen/ops/reflection_pad2d_backward_native.h>
1066
+ #include <ATen/ops/reflection_pad3d_native.h>
1067
+ #include <ATen/ops/reflection_pad3d_backward_native.h>
1068
+ #include <ATen/ops/relu_native.h>
1069
+ #include <ATen/ops/relu6_native.h>
1070
+ #include <ATen/ops/remainder_native.h>
1071
+ #include <ATen/ops/rename_native.h>
1072
+ #include <ATen/ops/renorm_native.h>
1073
+ #include <ATen/ops/repeat_native.h>
1074
+ #include <ATen/ops/repeat_interleave_native.h>
1075
+ #include <ATen/ops/replication_pad1d_native.h>
1076
+ #include <ATen/ops/replication_pad1d_backward_native.h>
1077
+ #include <ATen/ops/replication_pad2d_native.h>
1078
+ #include <ATen/ops/replication_pad2d_backward_native.h>
1079
+ #include <ATen/ops/replication_pad3d_native.h>
1080
+ #include <ATen/ops/replication_pad3d_backward_native.h>
1081
+ #include <ATen/ops/requires_grad_native.h>
1082
+ #include <ATen/ops/reshape_native.h>
1083
+ #include <ATen/ops/reshape_as_native.h>
1084
+ #include <ATen/ops/resize_native.h>
1085
+ #include <ATen/ops/resize_as_native.h>
1086
+ #include <ATen/ops/resize_as_sparse_native.h>
1087
+ #include <ATen/ops/resolve_conj_native.h>
1088
+ #include <ATen/ops/resolve_neg_native.h>
1089
+ #include <ATen/ops/result_type_native.h>
1090
+ #include <ATen/ops/retain_grad_native.h>
1091
+ #include <ATen/ops/retains_grad_native.h>
1092
+ #include <ATen/ops/rms_norm_native.h>
1093
+ #include <ATen/ops/rnn_relu_native.h>
1094
+ #include <ATen/ops/rnn_relu_cell_native.h>
1095
+ #include <ATen/ops/rnn_tanh_native.h>
1096
+ #include <ATen/ops/rnn_tanh_cell_native.h>
1097
+ #include <ATen/ops/roll_native.h>
1098
+ #include <ATen/ops/rot90_native.h>
1099
+ #include <ATen/ops/round_native.h>
1100
+ #include <ATen/ops/row_indices_native.h>
1101
+ #include <ATen/ops/row_indices_copy_native.h>
1102
+ #include <ATen/ops/row_stack_native.h>
1103
+ #include <ATen/ops/rrelu_native.h>
1104
+ #include <ATen/ops/rrelu_with_noise_native.h>
1105
+ #include <ATen/ops/rrelu_with_noise_backward_native.h>
1106
+ #include <ATen/ops/rshift_native.h>
1107
+ #include <ATen/ops/rsqrt_native.h>
1108
+ #include <ATen/ops/rsub_native.h>
1109
+ #include <ATen/ops/scalar_tensor_native.h>
1110
+ #include <ATen/ops/scaled_dot_product_attention_native.h>
1111
+ #include <ATen/ops/scatter_native.h>
1112
+ #include <ATen/ops/scatter_add_native.h>
1113
+ #include <ATen/ops/scatter_reduce_native.h>
1114
+ #include <ATen/ops/searchsorted_native.h>
1115
+ #include <ATen/ops/segment_reduce_native.h>
1116
+ #include <ATen/ops/select_native.h>
1117
+ #include <ATen/ops/select_backward_native.h>
1118
+ #include <ATen/ops/select_copy_native.h>
1119
+ #include <ATen/ops/select_scatter_native.h>
1120
+ #include <ATen/ops/selu_native.h>
1121
+ #include <ATen/ops/set_native.h>
1122
+ #include <ATen/ops/set_data_native.h>
1123
+ #include <ATen/ops/sgn_native.h>
1124
+ #include <ATen/ops/sigmoid_native.h>
1125
+ #include <ATen/ops/sigmoid_backward_native.h>
1126
+ #include <ATen/ops/sign_native.h>
1127
+ #include <ATen/ops/signbit_native.h>
1128
+ #include <ATen/ops/silu_native.h>
1129
+ #include <ATen/ops/silu_backward_native.h>
1130
+ #include <ATen/ops/sin_native.h>
1131
+ #include <ATen/ops/sinc_native.h>
1132
+ #include <ATen/ops/sinh_native.h>
1133
+ #include <ATen/ops/size_native.h>
1134
+ #include <ATen/ops/slice_native.h>
1135
+ #include <ATen/ops/slice_backward_native.h>
1136
+ #include <ATen/ops/slice_copy_native.h>
1137
+ #include <ATen/ops/slice_inverse_native.h>
1138
+ #include <ATen/ops/slice_scatter_native.h>
1139
+ #include <ATen/ops/slogdet_native.h>
1140
+ #include <ATen/ops/slow_conv3d_native.h>
1141
+ #include <ATen/ops/slow_conv3d_forward_native.h>
1142
+ #include <ATen/ops/slow_conv_dilated2d_native.h>
1143
+ #include <ATen/ops/slow_conv_dilated3d_native.h>
1144
+ #include <ATen/ops/slow_conv_transpose2d_native.h>
1145
+ #include <ATen/ops/slow_conv_transpose3d_native.h>
1146
+ #include <ATen/ops/smm_native.h>
1147
+ #include <ATen/ops/smooth_l1_loss_native.h>
1148
+ #include <ATen/ops/smooth_l1_loss_backward_native.h>
1149
+ #include <ATen/ops/soft_margin_loss_native.h>
1150
+ #include <ATen/ops/soft_margin_loss_backward_native.h>
1151
+ #include <ATen/ops/softmax_native.h>
1152
+ #include <ATen/ops/softplus_native.h>
1153
+ #include <ATen/ops/softplus_backward_native.h>
1154
+ #include <ATen/ops/softshrink_native.h>
1155
+ #include <ATen/ops/softshrink_backward_native.h>
1156
+ #include <ATen/ops/sort_native.h>
1157
+ #include <ATen/ops/sparse_bsc_tensor_native.h>
1158
+ #include <ATen/ops/sparse_bsr_tensor_native.h>
1159
+ #include <ATen/ops/sparse_compressed_tensor_native.h>
1160
+ #include <ATen/ops/sparse_coo_tensor_native.h>
1161
+ #include <ATen/ops/sparse_csc_tensor_native.h>
1162
+ #include <ATen/ops/sparse_csr_tensor_native.h>
1163
+ #include <ATen/ops/sparse_dim_native.h>
1164
+ #include <ATen/ops/sparse_mask_native.h>
1165
+ #include <ATen/ops/sparse_resize_native.h>
1166
+ #include <ATen/ops/sparse_resize_and_clear_native.h>
1167
+ #include <ATen/ops/sparse_sampled_addmm_native.h>
1168
+ #include <ATen/ops/special_airy_ai_native.h>
1169
+ #include <ATen/ops/special_bessel_j0_native.h>
1170
+ #include <ATen/ops/special_bessel_j1_native.h>
1171
+ #include <ATen/ops/special_bessel_y0_native.h>
1172
+ #include <ATen/ops/special_bessel_y1_native.h>
1173
+ #include <ATen/ops/special_chebyshev_polynomial_t_native.h>
1174
+ #include <ATen/ops/special_chebyshev_polynomial_u_native.h>
1175
+ #include <ATen/ops/special_chebyshev_polynomial_v_native.h>
1176
+ #include <ATen/ops/special_chebyshev_polynomial_w_native.h>
1177
+ #include <ATen/ops/special_digamma_native.h>
1178
+ #include <ATen/ops/special_entr_native.h>
1179
+ #include <ATen/ops/special_erf_native.h>
1180
+ #include <ATen/ops/special_erfc_native.h>
1181
+ #include <ATen/ops/special_erfcx_native.h>
1182
+ #include <ATen/ops/special_erfinv_native.h>
1183
+ #include <ATen/ops/special_exp2_native.h>
1184
+ #include <ATen/ops/special_expit_native.h>
1185
+ #include <ATen/ops/special_expm1_native.h>
1186
+ #include <ATen/ops/special_gammainc_native.h>
1187
+ #include <ATen/ops/special_gammaincc_native.h>
1188
+ #include <ATen/ops/special_gammaln_native.h>
1189
+ #include <ATen/ops/special_hermite_polynomial_h_native.h>
1190
+ #include <ATen/ops/special_hermite_polynomial_he_native.h>
1191
+ #include <ATen/ops/special_i0_native.h>
1192
+ #include <ATen/ops/special_i0e_native.h>
1193
+ #include <ATen/ops/special_i1_native.h>
1194
+ #include <ATen/ops/special_i1e_native.h>
1195
+ #include <ATen/ops/special_laguerre_polynomial_l_native.h>
1196
+ #include <ATen/ops/special_legendre_polynomial_p_native.h>
1197
+ #include <ATen/ops/special_log1p_native.h>
1198
+ #include <ATen/ops/special_log_ndtr_native.h>
1199
+ #include <ATen/ops/special_log_softmax_native.h>
1200
+ #include <ATen/ops/special_logit_native.h>
1201
+ #include <ATen/ops/special_logsumexp_native.h>
1202
+ #include <ATen/ops/special_modified_bessel_i0_native.h>
1203
+ #include <ATen/ops/special_modified_bessel_i1_native.h>
1204
+ #include <ATen/ops/special_modified_bessel_k0_native.h>
1205
+ #include <ATen/ops/special_modified_bessel_k1_native.h>
1206
+ #include <ATen/ops/special_multigammaln_native.h>
1207
+ #include <ATen/ops/special_ndtr_native.h>
1208
+ #include <ATen/ops/special_ndtri_native.h>
1209
+ #include <ATen/ops/special_polygamma_native.h>
1210
+ #include <ATen/ops/special_psi_native.h>
1211
+ #include <ATen/ops/special_round_native.h>
1212
+ #include <ATen/ops/special_scaled_modified_bessel_k0_native.h>
1213
+ #include <ATen/ops/special_scaled_modified_bessel_k1_native.h>
1214
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_t_native.h>
1215
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_u_native.h>
1216
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_v_native.h>
1217
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_w_native.h>
1218
+ #include <ATen/ops/special_sinc_native.h>
1219
+ #include <ATen/ops/special_softmax_native.h>
1220
+ #include <ATen/ops/special_spherical_bessel_j0_native.h>
1221
+ #include <ATen/ops/special_xlog1py_native.h>
1222
+ #include <ATen/ops/special_xlogy_native.h>
1223
+ #include <ATen/ops/special_zeta_native.h>
1224
+ #include <ATen/ops/split_native.h>
1225
+ #include <ATen/ops/split_copy_native.h>
1226
+ #include <ATen/ops/split_with_sizes_native.h>
1227
+ #include <ATen/ops/split_with_sizes_copy_native.h>
1228
+ #include <ATen/ops/sqrt_native.h>
1229
+ #include <ATen/ops/square_native.h>
1230
+ #include <ATen/ops/squeeze_native.h>
1231
+ #include <ATen/ops/squeeze_copy_native.h>
1232
+ #include <ATen/ops/sspaddmm_native.h>
1233
+ #include <ATen/ops/stack_native.h>
1234
+ #include <ATen/ops/std_native.h>
1235
+ #include <ATen/ops/std_mean_native.h>
1236
+ #include <ATen/ops/stft_native.h>
1237
+ #include <ATen/ops/stride_native.h>
1238
+ #include <ATen/ops/sub_native.h>
1239
+ #include <ATen/ops/subtract_native.h>
1240
+ #include <ATen/ops/sum_native.h>
1241
+ #include <ATen/ops/sum_to_size_native.h>
1242
+ #include <ATen/ops/svd_native.h>
1243
+ #include <ATen/ops/swapaxes_native.h>
1244
+ #include <ATen/ops/swapdims_native.h>
1245
+ #include <ATen/ops/sym_constrain_range_native.h>
1246
+ #include <ATen/ops/sym_constrain_range_for_size_native.h>
1247
+ #include <ATen/ops/sym_numel_native.h>
1248
+ #include <ATen/ops/sym_size_native.h>
1249
+ #include <ATen/ops/sym_storage_offset_native.h>
1250
+ #include <ATen/ops/sym_stride_native.h>
1251
+ #include <ATen/ops/t_native.h>
1252
+ #include <ATen/ops/t_copy_native.h>
1253
+ #include <ATen/ops/take_native.h>
1254
+ #include <ATen/ops/take_along_dim_native.h>
1255
+ #include <ATen/ops/tan_native.h>
1256
+ #include <ATen/ops/tanh_native.h>
1257
+ #include <ATen/ops/tanh_backward_native.h>
1258
+ #include <ATen/ops/tensor_split_native.h>
1259
+ #include <ATen/ops/tensordot_native.h>
1260
+ #include <ATen/ops/thnn_conv2d_native.h>
1261
+ #include <ATen/ops/threshold_native.h>
1262
+ #include <ATen/ops/threshold_backward_native.h>
1263
+ #include <ATen/ops/tile_native.h>
1264
+ #include <ATen/ops/to_native.h>
1265
+ #include <ATen/ops/to_dense_native.h>
1266
+ #include <ATen/ops/to_dense_backward_native.h>
1267
+ #include <ATen/ops/to_mkldnn_native.h>
1268
+ #include <ATen/ops/to_mkldnn_backward_native.h>
1269
+ #include <ATen/ops/to_padded_tensor_native.h>
1270
+ #include <ATen/ops/to_sparse_native.h>
1271
+ #include <ATen/ops/to_sparse_bsc_native.h>
1272
+ #include <ATen/ops/to_sparse_bsr_native.h>
1273
+ #include <ATen/ops/to_sparse_csc_native.h>
1274
+ #include <ATen/ops/to_sparse_csr_native.h>
1275
+ #include <ATen/ops/topk_native.h>
1276
+ #include <ATen/ops/trace_native.h>
1277
+ #include <ATen/ops/trace_backward_native.h>
1278
+ #include <ATen/ops/transpose_native.h>
1279
+ #include <ATen/ops/transpose_copy_native.h>
1280
+ #include <ATen/ops/trapezoid_native.h>
1281
+ #include <ATen/ops/trapz_native.h>
1282
+ #include <ATen/ops/triangular_solve_native.h>
1283
+ #include <ATen/ops/tril_native.h>
1284
+ #include <ATen/ops/tril_indices_native.h>
1285
+ #include <ATen/ops/triplet_margin_loss_native.h>
1286
+ #include <ATen/ops/triu_native.h>
1287
+ #include <ATen/ops/triu_indices_native.h>
1288
+ #include <ATen/ops/true_divide_native.h>
1289
+ #include <ATen/ops/trunc_native.h>
1290
+ #include <ATen/ops/type_as_native.h>
1291
+ #include <ATen/ops/unbind_native.h>
1292
+ #include <ATen/ops/unbind_copy_native.h>
1293
+ #include <ATen/ops/unflatten_native.h>
1294
+ #include <ATen/ops/unflatten_dense_tensors_native.h>
1295
+ #include <ATen/ops/unfold_native.h>
1296
+ #include <ATen/ops/unfold_backward_native.h>
1297
+ #include <ATen/ops/unfold_copy_native.h>
1298
+ #include <ATen/ops/uniform_native.h>
1299
+ #include <ATen/ops/unique_consecutive_native.h>
1300
+ #include <ATen/ops/unique_dim_native.h>
1301
+ #include <ATen/ops/unique_dim_consecutive_native.h>
1302
+ #include <ATen/ops/unsafe_chunk_native.h>
1303
+ #include <ATen/ops/unsafe_split_native.h>
1304
+ #include <ATen/ops/unsafe_split_with_sizes_native.h>
1305
+ #include <ATen/ops/unsqueeze_native.h>
1306
+ #include <ATen/ops/unsqueeze_copy_native.h>
1307
+ #include <ATen/ops/upsample_bicubic2d_native.h>
1308
+ #include <ATen/ops/upsample_bicubic2d_backward_native.h>
1309
+ #include <ATen/ops/upsample_bilinear2d_native.h>
1310
+ #include <ATen/ops/upsample_bilinear2d_backward_native.h>
1311
+ #include <ATen/ops/upsample_linear1d_native.h>
1312
+ #include <ATen/ops/upsample_linear1d_backward_native.h>
1313
+ #include <ATen/ops/upsample_nearest1d_native.h>
1314
+ #include <ATen/ops/upsample_nearest1d_backward_native.h>
1315
+ #include <ATen/ops/upsample_nearest2d_native.h>
1316
+ #include <ATen/ops/upsample_nearest2d_backward_native.h>
1317
+ #include <ATen/ops/upsample_nearest3d_native.h>
1318
+ #include <ATen/ops/upsample_nearest3d_backward_native.h>
1319
+ #include <ATen/ops/upsample_trilinear3d_native.h>
1320
+ #include <ATen/ops/upsample_trilinear3d_backward_native.h>
1321
+ #include <ATen/ops/value_selecting_reduction_backward_native.h>
1322
+ #include <ATen/ops/values_native.h>
1323
+ #include <ATen/ops/values_copy_native.h>
1324
+ #include <ATen/ops/vander_native.h>
1325
+ #include <ATen/ops/var_native.h>
1326
+ #include <ATen/ops/var_mean_native.h>
1327
+ #include <ATen/ops/vdot_native.h>
1328
+ #include <ATen/ops/view_native.h>
1329
+ #include <ATen/ops/view_as_native.h>
1330
+ #include <ATen/ops/view_as_complex_native.h>
1331
+ #include <ATen/ops/view_as_complex_copy_native.h>
1332
+ #include <ATen/ops/view_as_real_native.h>
1333
+ #include <ATen/ops/view_as_real_copy_native.h>
1334
+ #include <ATen/ops/view_copy_native.h>
1335
+ #include <ATen/ops/vsplit_native.h>
1336
+ #include <ATen/ops/vstack_native.h>
1337
+ #include <ATen/ops/where_native.h>
1338
+ #include <ATen/ops/xlogy_native.h>
1339
+ #include <ATen/ops/xor_native.h>
1340
+ #include <ATen/ops/zero_native.h>
1341
+ #include <ATen/ops/zeros_native.h>
1342
+ #include <ATen/ops/zeros_like_native.h>
1343
+
1344
+
.venv/lib/python3.11/site-packages/torch/include/ATen/NestedTensorImpl.h ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/MemoryOverlap.h>
3
+ #include <ATen/Tensor.h>
4
+ #include <c10/core/DispatchKey.h>
5
+ #include <c10/core/DispatchKeySet.h>
6
+ #include <c10/core/MemoryFormat.h>
7
+ #include <c10/core/TensorImpl.h>
8
+ #include <c10/util/ArrayRef.h>
9
+ #include <c10/util/Exception.h>
10
+ #include <c10/util/Metaprogramming.h>
11
+ #include <c10/util/irange.h>
12
+
13
+ namespace at::native {
14
+ struct NestedTensorImpl;
15
+ inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
16
+ int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
17
+ at::Tensor construct_nested_strides(const at::Tensor& nested_size);
18
+ at::Tensor construct_offsets(const at::Tensor& nested_size);
19
+
20
// NestedTensorImpl is the TensorImpl backing tensors that hold a "nested"
// (jagged) collection of constituent tensors packed into one flat buffer.
// Per-constituent metadata lives in three companion tensors:
//   - nested_sizes_:    each constituent's sizes (one row per constituent)
//   - nested_strides_:  each constituent's strides (one row per constituent)
//   - storage_offsets_: each constituent's start position in the buffer
struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
  explicit NestedTensorImpl(
      Storage storage,
      c10::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);

  explicit NestedTensorImpl(
      const at::Tensor& buffer,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);
  // assume contiguous, `nested_strides` and `offsets`
  // can be inferred from `nested_sizes`
  explicit NestedTensorImpl(
      const at::Tensor& buffer,
      const at::Tensor& nested_sizes);

  // This constructor is used creating view tensors from nested tensors
  explicit NestedTensorImpl(
      c10::TensorImpl::ImplType impl_type,
      const at::Tensor& base_tensor,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);

  // TODO: don't expose private implementation details like this; in
  // particular, resizing this tensor will mess up our dim() and
  // callers cannot fix it.
  const Tensor& get_nested_sizes() const {
    return nested_sizes_;
  }
  // TODO: don't expose private implementation details like this
  const Tensor& get_nested_strides() const {
    return nested_strides_;
  }
  const Tensor& get_storage_offsets() const {
    return storage_offsets_;
  }
  // Returns nullopt if the ith dimension is irregular. The ith dimension
  // of a NestedTensor is regular if the unbound tensors match in
  // size at the (i-1)th dimension.
  std::optional<int64_t> opt_size(int64_t d) const;

  // Like opt_size(), but raises an error instead of returning nullopt
  // when the requested dimension is irregular.
  int64_t size(int64_t d) const {
    std::optional<int64_t> optional_size = this->opt_size(d);
    TORCH_CHECK(
        optional_size.has_value(),
        "Given dimension ",
        d,
        " is irregular and does not have a size.");
    return *optional_size;
  }
  /**
   * Return a view of the nested tensor as a 1 dimensional contiguous tensor.
   *
   * The buffer tensor created by this function shares the same storage_impl as
   * the original nested tensor, and therefore can be seen as a view.
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_buffer() const {
    TORCH_CHECK(
        nested_tensor_impl_is_contiguous(this),
        "NestedTensor must be contiguous to get buffer.");
    return get_unsafe_storage_as_tensor();
  }
  /**
   * If possible use get_buffer() instead. This function returns the storage
   * as a tensor directly, which is not safe to use in general. If using this
   * function, The caller must ensure to account for nested_sizes,
   * nested_strides and storage_offsets.
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_unsafe_storage_as_tensor() const {
    auto buffer_key_set_ = generate_buffer_key_set();
    const auto buffer_size = get_buffer_size();
    // Build a dense 1-D view over the same storage; VIEW keeps it aliasing
    // this tensor's storage_impl rather than copying.
    auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
        c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
    buffer_tensor_impl->set_sizes_contiguous(
        c10::makeArrayRef(static_cast<int64_t>(buffer_size)));
    return Tensor(buffer_tensor_impl);
  }

  // Number of elements the flat buffer can hold (storage bytes / item size).
  size_t get_buffer_size() const {
    return storage_.nbytes() / data_type_.itemsize();
  }

 protected:
  const char* tensorimpl_type_name() const override;

  // TODO: numel_custom and is_contiguous_custom can be profitably overridden
  // with real implementations
  int64_t numel_custom() const override;
  c10::SymInt sym_numel_custom() const override;
  bool is_contiguous_custom(MemoryFormat) const override;
  int64_t size_custom(int64_t d) const override {
    return this->size(d);
  }
  c10::SymInt sym_size_custom(int64_t d) const override {
    return c10::SymInt{this->size(d)};
  }
  IntArrayRef sizes_custom() const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  IntArrayRef strides_custom() const override;
  c10::SymIntArrayRef sym_strides_custom() const override;

  // this one is real
  int64_t dim_custom() const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    copy_tensor_metadata(
        /*src_impl=*/impl.get(),
        /*dest_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  }

 private:
  // Must be called after any changes to our dim() to sync the state
  // to TensorImpl.
  void refresh_dim();

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::Tensor nested_sizes_, nested_strides_;
  // The starting positions of the underlying tensors in contiguous buffer
  // i.e. the buffer memory offsets to get the underlying tensors
  // The reason to keep this metadata is that, without strong enough constraint
  // it cannot be derived from `nested_sizes_`
  // and `nested_strides_`:
  // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2]
  // this can happen e.g. after slicing a nested tensor
  // 2. when multiple tensors share a same memory
  // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2]
  // Some strong enough constraints are:
  // 1. every underlying tensor is contiguous in memory
  // && nesting in ascending order
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::Tensor storage_offsets_;
  // NOTE: -1 here means the size is missing
  // Optional to allow it to be computed lazily from nested.
  // TODO: maybe we can remove this metadata since
  // we can compute it from `nested_sizes_`
  mutable std::optional<std::vector<int64_t>> opt_sizes_;

  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  /**
   * Generates a non-nested key_set from a nested tensor.
   *
   * For many nested tensor kernel implementations a buffer tensor
   * is generated and redispatched to a non-nested kernel this function
   * generates the key set used by that buffer tensor
   *
   * @return Appropriate key set for non-nested tensor
   */
  inline c10::DispatchKeySet generate_buffer_key_set() const {
    auto buffer_key_set = this->key_set();
    const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset);
    // Remove nested tensor specific keys
    buffer_key_set = buffer_key_set -
        c10::DispatchKeySet{
            c10::DispatchKey::NestedTensor,
            c10::DispatchKey::AutogradNestedTensor};

    // Add dense tensor specific keys
    buffer_key_set =
        buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
    // Preserve autograd participation if the nested tensor had it.
    buffer_key_set = Autograd
        ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
        : buffer_key_set;

    return buffer_key_set;
  }
};
209
+
210
+ inline NestedTensorImpl* get_nested_tensor_impl_or_null(
211
+ const at::Tensor& tensor) {
212
+ if (tensor.is_nested()) {
213
+ return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
214
+ }
215
+ return nullptr;
216
+ }
217
+
218
+ inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
219
+ TORCH_CHECK(
220
+ tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
221
+ return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
222
+ }
223
+
224
// Returns true iff the nested tensor's flat buffer is fully packed: every
// constituent tensor is itself contiguous (row-major), and the constituents
// follow one another in the buffer in order, with no gaps between them.
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
  int64_t ntensors = nt->size(0);
  if (ntensors == 0) {
    // An empty nested tensor is trivially contiguous.
    return true;
  }
  const Tensor &sizemat = nt->get_nested_sizes(),
               &stridemat = nt->get_nested_strides();
  const int64_t* offsets_ptr =
      nt->get_storage_offsets().const_data_ptr<int64_t>();
  int64_t orig_dim = sizemat.size(1);
  // nesting scalars
  if (orig_dim == 0) {
    // each scalar must be contiguous
    // if there is blank memory between underlying scalars
    // (scalar i must sit exactly at buffer offset i)
    for (int64_t i = 0; i < ntensors; i++) {
      if (offsets_ptr[i] != i) {
        return false;
      }
    }
  }
  // nesting tensors
  else {
    // if any underlying tensor is non-contiguous
    // (check each row of the stride matrix against the row-major strides
    // implied by the corresponding row of the size matrix)
    const int64_t *sizemat_ptr = sizemat.const_data_ptr<int64_t>(),
                  *stridemat_ptr = stridemat.const_data_ptr<int64_t>();
    for (int64_t i = 0; i < ntensors; i++) {
      // Innermost stride must be 1 for a contiguous layout.
      if (stridemat_ptr[orig_dim - 1] != 1) {
        return false;
      }
      // Walk outward: each stride must equal the product of the inner sizes.
      int64_t product = sizemat_ptr[orig_dim - 1];
      for (int64_t j = orig_dim - 2; j >= 0; j--) {
        if (stridemat_ptr[j] != product) {
          return false;
        }
        product *= sizemat_ptr[j];
      }
      // Advance to the next constituent's row of metadata.
      sizemat_ptr += orig_dim;
      stridemat_ptr += orig_dim;
    }
    // if there is blank memory between underlying tensors
    if (offsets_ptr[0] != 0) {
      return false;
    }
    // Each constituent must start exactly where the previous one ends:
    // previous offset + (outermost size * outermost stride) of the previous.
    sizemat_ptr = sizemat.const_data_ptr<int64_t>();
    stridemat_ptr = stridemat.const_data_ptr<int64_t>();
    for (int64_t i = 1; i < ntensors; i++) {
      if (offsets_ptr[i] !=
          offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
        return false;
      }
      sizemat_ptr += orig_dim;
      stridemat_ptr += orig_dim;
    }
  }
  // everything is fine
  return true;
}
281
+
282
+ inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
283
+ return get_nested_tensor_impl(tensor)->get_nested_sizes();
284
+ }
285
+
286
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/OpMathType.h ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/ScalarType.h>
4
+ #include <c10/util/BFloat16.h>
5
+ #include <c10/util/Exception.h>
6
+ #include <c10/util/Float8_e4m3fn.h>
7
+ #include <c10/util/Float8_e4m3fnuz.h>
8
+ #include <c10/util/Float8_e5m2.h>
9
+ #include <c10/util/Float8_e5m2fnuz.h>
10
+ #include <c10/util/Half.h>
11
+
12
+ namespace at {
13
+
14
// For FP16 or BFloat16 inputs, ops should perform internal math in FP32.
// OpMathType maps a storage scalar type to the (wider) type that should be
// used for intermediate arithmetic. The primary template is the identity:
// types without a specialization compute in their own precision.
template <typename scalar_t>
struct OpMathType {
  using type = scalar_t;
};
// Half precision accumulates in float.
template <>
struct OpMathType<at::Half> {
  using type = float;
};
// BFloat16 accumulates in float.
template <>
struct OpMathType<at::BFloat16> {
  using type = float;
};
// All four FP8 variants accumulate in float.
template <>
struct OpMathType<at::Float8_e5m2> {
  using type = float;
};
template <>
struct OpMathType<at::Float8_e4m3fn> {
  using type = float;
};
template <>
struct OpMathType<at::Float8_e5m2fnuz> {
  using type = float;
};
template <>
struct OpMathType<at::Float8_e4m3fnuz> {
  using type = float;
};
// Complex half maps to complex float.
template <>
struct OpMathType<c10::complex<Half>> {
  using type = c10::complex<float>;
};

// Convenience alias: opmath_type<T> is the math/accumulation type for T.
template <typename T>
using opmath_type = typename OpMathType<T>::type;
50
+
51
namespace {

// Runtime counterpart of opmath_type<>: maps a ScalarType enum value to the
// ScalarType of its math/accumulation type. The DEFINE_CASE macro expands,
// for every (cpp type, enum tag) pair, to a case returning the ScalarType
// of at::opmath_type<cpp type>.
// NOTE(review): this sits in an anonymous namespace inside a header, so each
// translation unit gets its own copy; presumably intentional for an inline
// helper, but worth confirming against the project's header conventions.
inline c10::ScalarType toOpMathType(const c10::ScalarType type) {
  switch (type) {
#define DEFINE_CASE(scalar_t, TypeNum) \
  case ScalarType::TypeNum:            \
    return CppTypeToScalarType<at::opmath_type<scalar_t>>::value;

    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
#undef DEFINE_CASE

    default:
      // Unreachable for the types covered by the macro above.
      TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
  }
}

} // namespace
68
+
69
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/PadNd.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/util/Exception.h>
3
+ #include <c10/util/string_view.h>
4
+
5
+ namespace at {
6
+
7
+ enum class padding_mode {
8
+ reflect,
9
+ replicate,
10
+ circular,
11
+ constant,
12
+ };
13
+
14
+ static inline c10::string_view padding_mode_string(padding_mode m) {
15
+ switch (m) {
16
+ case padding_mode::reflect:
17
+ return "reflect";
18
+ case padding_mode::replicate:
19
+ return "replicate";
20
+ case padding_mode::circular:
21
+ return "circular";
22
+ case padding_mode::constant:
23
+ return "constant";
24
+ }
25
+ TORCH_CHECK(false, "Invalid padding mode (", static_cast<int64_t>(m), ")");
26
+ }
27
+
28
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/Parallel.h ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/Config.h>
3
+ #include <c10/macros/Macros.h>
4
+ #include <functional>
5
+ #include <string>
6
+
7
+ namespace at {
8
+
9
// Integer ceiling division: smallest integer >= x / y (for y != 0).
// Rewritten from the classic (x + y - 1) / y form, which has two defects:
//  - x + y - 1 overflows (UB) when x is near INT64_MAX, and
//  - it returns the wrong value for exactly-divisible negative x
//    (e.g. divup(-4, 2) was -1 instead of -2).
// x / y truncates toward zero; we add one step iff there is a nonzero
// remainder and the exact quotient is positive (operands have equal sign).
inline int64_t divup(int64_t x, int64_t y) {
  return x / y + ((x % y != 0 && (x ^ y) >= 0) ? 1 : 0);
}
12
+
13
+ // Called during new thread initialization
14
+ TORCH_API void init_num_threads();
15
+
16
+ // Sets the number of threads to be used in parallel region
17
+ TORCH_API void set_num_threads(int);
18
+
19
+ // Returns the maximum number of threads that may be used in a parallel region
20
+ TORCH_API int get_num_threads();
21
+
22
+ // Returns the current thread number (starting from 0)
23
+ // in the current parallel region, or 0 in the sequential region
24
+ TORCH_API int get_thread_num();
25
+
26
+ // Checks whether the code runs in parallel region
27
+ TORCH_API bool in_parallel_region();
28
+
29
+ namespace internal {
30
+
31
+ // Initialise num_threads lazily at first parallel call
32
+ inline void lazy_init_num_threads() {
33
+ thread_local bool init = false;
34
+ if (C10_UNLIKELY(!init)) {
35
+ at::init_num_threads();
36
+ init = true;
37
+ }
38
+ }
39
+
40
+ TORCH_API void set_thread_num(int);
41
+
42
+ class TORCH_API ThreadIdGuard {
43
+ public:
44
+ ThreadIdGuard(int new_id) : old_id_(at::get_thread_num()) {
45
+ set_thread_num(new_id);
46
+ }
47
+
48
+ ~ThreadIdGuard() {
49
+ set_thread_num(old_id_);
50
+ }
51
+
52
+ private:
53
+ int old_id_;
54
+ };
55
+
56
+ } // namespace internal
57
+
58
+ /*
59
+ parallel_for
60
+
61
+ begin: index at which to start applying user function
62
+
63
+ end: index at which to stop applying user function
64
+
65
+ grain_size: number of elements per chunk. impacts the degree of parallelization
66
+
67
+ f: user function applied in parallel to the chunks, signature:
68
+ void f(int64_t begin, int64_t end)
69
+
70
+ Warning: parallel_for does NOT copy thread local
71
+ states from the current thread to the worker threads.
72
+ This means for example that Tensor operations CANNOT be used in the
73
+ body of your function, only data pointers.
74
+ */
75
+ template <class F>
76
+ inline void parallel_for(
77
+ const int64_t begin,
78
+ const int64_t end,
79
+ const int64_t grain_size,
80
+ const F& f);
81
+
82
+ /*
83
+ parallel_reduce
84
+
85
+ begin: index at which to start applying reduction
86
+
87
+ end: index at which to stop applying reduction
88
+
89
+ grain_size: number of elements per chunk. impacts number of elements in
90
+ intermediate results tensor and degree of parallelization.
91
+
92
+ ident: identity for binary combination function sf. sf(ident, x) needs to return
93
+ x.
94
+
95
+ f: function for reduction over a chunk. f needs to be of signature scalar_t
96
+ f(int64_t partial_begin, int64_t partial_end, scalar_t identifiy)
97
+
98
+ sf: function to combine two partial results. sf needs to be of signature
99
+ scalar_t sf(scalar_t x, scalar_t y)
100
+
101
+ For example, you might have a tensor of 10000 entires and want to sum together
102
+ all the elements. Parallel_reduce with a grain_size of 2500 will then allocate
103
+ an intermediate result tensor with 4 elements. Then it will execute the function
104
+ "f" you provide and pass the beginning and end index of these chunks, so
105
+ 0-2499, 2500-4999, etc. and the combination identity. It will then write out
106
+ the result from each of these chunks into the intermediate result tensor. After
107
+ that it'll reduce the partial results from each chunk into a single number using
108
+ the combination function sf and the identity ident. For a total summation this
109
+ would be "+" and 0 respectively. This is similar to tbb's approach [1], where
110
+ you need to provide a function to accumulate a subrange, a function to combine
111
+ two partial results and an identity.
112
+
113
+ Warning: parallel_reduce does NOT copy thread local
114
+ states from the current thread to the worker threads.
115
+ This means for example that Tensor operations CANNOT be used in the
116
+ body of your function, only data pointers.
117
+
118
+ [1] https://software.intel.com/en-us/node/506154
119
+ */
120
+ template <class scalar_t, class F, class SF>
121
+ inline scalar_t parallel_reduce(
122
+ const int64_t begin,
123
+ const int64_t end,
124
+ const int64_t grain_size,
125
+ const scalar_t ident,
126
+ const F& f,
127
+ const SF& sf);
128
+
129
+ // Returns a detailed string describing parallelization settings
130
+ TORCH_API std::string get_parallel_info();
131
+
132
+ // Sets number of threads used for inter-op parallelism
133
+ TORCH_API void set_num_interop_threads(int);
134
+
135
+ // Returns the number of threads used for inter-op parallelism
136
+ TORCH_API int get_num_interop_threads();
137
+
138
+ // Launches inter-op parallel task
139
+ TORCH_API void launch(std::function<void()> func);
140
+ namespace internal {
141
+ void launch_no_thread_state(std::function<void()> fn);
142
+ } // namespace internal
143
+
144
+ // Launches intra-op parallel task
145
+ TORCH_API void intraop_launch(std::function<void()> func);
146
+
147
+ // Returns number of intra-op threads used by default
148
+ TORCH_API int intraop_default_num_threads();
149
+
150
+ } // namespace at
151
+
152
+ #if AT_PARALLEL_OPENMP
153
+ #include <ATen/ParallelOpenMP.h> // IWYU pragma: keep
154
+ #elif AT_PARALLEL_NATIVE
155
+ #include <ATen/ParallelNative.h> // IWYU pragma: keep
156
+ #endif
157
+
158
+ #include <ATen/Parallel-inl.h> // IWYU pragma: keep
.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelFuture.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/ivalue.h>
4
+ #include <c10/macros/Macros.h>
5
+ #include <functional>
6
+
7
+ namespace at {
8
+
9
+ // Launches intra-op parallel task, returns a future
10
+ TORCH_API c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
11
+ std::function<void()> func);
12
+
13
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/RegistrationDeclarations.h ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torch/include/ATen/SavedTensorHooks.h ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/SafePyObject.h>
4
+ #include <c10/macros/Export.h>
5
+ #include <c10/util/python_stub.h>
6
+ #include <optional>
7
+ #include <stack>
8
+ #include <string>
9
+
10
+ #include <utility>
11
+
12
+ namespace at {
13
+
14
+ namespace impl {
15
+
16
+ struct TORCH_API SavedTensorDefaultHooksTLS {
17
+ // PyObject is defined in c10/util/python_stub.h
18
+ std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;
19
+
20
+ // See NOTE: [Disabling SavedTensorDefaultHooks] for context
21
+ // NOTE: [disabled_error_message invariant]
22
+ // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
23
+ // We did this for efficiency (so we didn't have to keep a separate bool
24
+ // around)
25
+ std::optional<std::string> disabled_error_message;
26
+
27
+ // See NOTE: [Deferring tensor pack/unpack hooks until runtime]
28
+ bool is_tracing = false;
29
+ };
30
+
31
+ } // namespace impl
32
+
33
+ struct TORCH_API SavedTensorDefaultHooks {
34
+ static void push_hooks(
35
+ c10::SafePyObject pack_hook,
36
+ c10::SafePyObject unpack_hook);
37
+ static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
38
+ static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
39
+ get_hooks();
40
+ static void lazy_initialize();
41
+
42
+ static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
43
+ static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
44
+
45
+ // NOTE: [Disabling SavedTensorDefaultHooks]
46
+ // A developer of a PyTorch feature may choose to disable SavedTensorDefault
47
+ // hooks, especially if their feature does not work with it. If they are
48
+ // disabled, then the following will raise an error:
49
+ // - Attempting to push_hooks
50
+ // - calling disable(message) with a non-zero stack (hooks) size
51
+ static void disable(const std::string& error_message);
52
+ static void enable();
53
+ static bool is_enabled();
54
+ static const std::optional<std::string>& get_disabled_error_message();
55
+
56
+ // NOTE: [Deferring tensor pack/unpack hooks until runtime]
57
+ // To preserve eager semantics of pack/unpack hooks firing only once per saved
58
+ // variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using
59
+ // disable() would loud error at trace time, and pushing a no-op hook would
60
+ // fail when the traced code is wrapped in a disable_saved_tensors_hooks ctx.
61
+ // To do so, we disable these hooks during tracing. See
62
+ // https://github.com/pytorch/pytorch/issues/113263.
63
+ static bool set_tracing(bool is_tracing);
64
+ };
65
+
66
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/Scalar.h ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Scalar.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/ScalarOps.h ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Tensor.h>
4
+ #include <c10/core/Scalar.h>
5
+
6
+ #ifndef AT_PER_OPERATOR_HEADERS
7
+ #include <ATen/Functions.h>
8
+ #else
9
+ #include <ATen/ops/scalar_tensor.h>
10
+ #endif
11
+
12
+ namespace at::detail {
13
+ // When filling a number to 1-element CPU tensor, we want to skip
14
+ // everything but manipulate data ptr directly.
15
+ // Ideally this fast pass should be implemented in TensorIterator,
16
+ // but we also want to skip compute_types which in not avoidable
17
+ // in TensorIterator for now.
18
+ Tensor& scalar_fill(Tensor& self, const Scalar& value);
19
+ TORCH_API Tensor scalar_tensor_static(
20
+ const Scalar& s,
21
+ std::optional<ScalarType> dtype_opt,
22
+ std::optional<Device> device_opt);
23
+ } // namespace at::detail
24
+
25
+ // This is in the c10 namespace because we use ADL to find the functions in it.
26
+ namespace c10 {
27
+
28
// FIXME: this should be (and was) Scalar::toTensor, but there is currently no
// way to implement this without going through Derived Types (which are not part
// of core).
//
// Materializes a Scalar as a tensor on `device` (CPU by default), with the
// dtype taken from the scalar itself (s.type()).
inline at::Tensor scalar_to_tensor(
    const Scalar& s,
    const Device device = at::kCPU) {
  // This is the fast track we have for CPU scalar tensors.
  // scalar_tensor_static bypasses the full dispatcher path.
  if (device == at::kCPU) {
    return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
  }
  return at::scalar_tensor(s, at::device(device).dtype(s.type()));
}
40
+
41
+ } // namespace c10
42
+
43
+ namespace at::native {
44
+
45
+ inline Tensor wrapped_scalar_tensor(
46
+ const Scalar& scalar,
47
+ const Device device = at::kCPU) {
48
+ auto tensor = scalar_to_tensor(scalar, device);
49
+ tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
50
+ return tensor;
51
+ }
52
+
53
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/ScalarType.h ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/ATenGeneral.h> // for BC reasons
3
+ #include <c10/core/Backend.h>
4
+ #include <c10/core/ScalarType.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/SparseCsrTensorImpl.h ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Tensor.h>
4
+ #include <c10/core/TensorImpl.h>
5
+ #include <c10/core/impl/TorchDispatchModeTLS.h>
6
+ #include <c10/util/Exception.h>
7
+ namespace at {
8
+
9
+ // Struct implementing a sparse CSR tensor. It uses three 1-D tensors for
10
+ // denoting the data: `crow_indices_`, `col_indices_` and `values_`.
11
+ // The `crow_indices_` tensor is a integer tensor of shape `(size(0) + 1)`
12
+ // that represents the compressed row indices of the CSR tensor. The
13
+ // `col_indices_` tensor is an integer tensor of shape `(nnz())`
14
+ // that explicitly stores the column indices of each value of the sparse
15
+ // tensor. The `values_` tensor can be of any pytorch-supported data type
16
+ // and has shape `(nnz())`.
17
+ //
18
+ // Since the main advantage of the CSR format over the COO format is speed of
19
+ // computation, care must be taken to facilitate smooth interfacing of
20
+ // these data structures with optimized libraries such as MKL and MAGMA.
21
+ // Since the MKL interface for pytorch currently uses indexing with int32
22
+ // type, it is important to make sure that the `crow_indices` and `col_indices`
23
+ // are of type int32 when calling MKL routines such as SPMM or SPMV.
24
+ //
25
+ // If not calling MKL, it should be alright to use 64 bit integer tensors
26
+ // for indexing.
27
// TensorImpl backing all four sparse compressed layouts (CSR/CSC/BSR/BSC).
struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
  // NOTE(review): despite the crow_/col_ names, layout_ may be any of the
  // four compressed layouts, so these members presumably also hold
  // ccol/row indices for CSC/BSC — confirm against set_member_tensors.
  Tensor crow_indices_;
  Tensor col_indices_;
  Tensor values_;
  Layout layout_;

 public:
  explicit SparseCsrTensorImpl(
      at::DispatchKeySet,
      at::Device device,
      Layout layout,
      const caffe2::TypeMeta);

  // Resizing / member installation; implemented out of line.
  void resize_(int64_t nnz, IntArrayRef size);
  void resize_and_clear_(
      int64_t sparse_dim,
      int64_t dense_dim,
      IntArrayRef size);
  void resize_as_sparse_compressed_tensor_(const Tensor& src);
  // Overloads for symbolic and concrete sizes.
  void set_member_tensors(
      const Tensor& crow_indices,
      const Tensor& col_indices,
      const Tensor& values,
      c10::SymIntArrayRef size);
  void set_member_tensors(
      const Tensor& crow_indices,
      const Tensor& col_indices,
      const Tensor& values,
      IntArrayRef size);
  // Layout-agnostic accessors for the stored index/value tensors.
  const Tensor& compressed_indices() const {
    return crow_indices_;
  }
  const Tensor& plain_indices() const {
    return col_indices_;
  }
  const Tensor& values() const {
    return values_;
  }
  // Number of specified entries: the length of the last dimension of the
  // plain-indices tensor.
  int64_t nnz() {
    return col_indices_.size(-1);
  }

  // Leading batch dimensions, derived from the compressed-indices tensor
  // (all dims except its final one).
  inline int64_t batch_dim() const noexcept {
    return crow_indices_.dim() - 1;
  }

  // Sparse compressed tensors always have exactly two sparse dimensions.
  inline int64_t sparse_dim() const noexcept {
    return 2;
  }

  // Trailing dense dims of values_, after batch dims, block dims (for
  // blocked layouts), and the specified-elements dim.
  inline int64_t dense_dim() const noexcept {
    return values_.dim() - batch_dim() - block_dim() - 1;
  }

 private:
  // Blocked layouts (BSR/BSC) contribute two block dims inside values_.
  inline int64_t block_dim() const noexcept {
    return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0);
  }

 protected:
  IntArrayRef strides_custom() const override;
  SymIntArrayRef sym_strides_custom() const override;
  bool is_contiguous_custom(MemoryFormat) const override;

 public:
  // Size/stride/offset mutation overrides; implemented out of line.
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
  Layout layout_impl() const override {
    return layout_;
  }
  // Only the four sparse compressed layouts are accepted.
  void set_layout(Layout layout) {
    switch (layout) {
      case kSparseCsr:
      case kSparseCsc:
      case kSparseBsr:
      case kSparseBsc:
        layout_ = layout;
        break;
      default:
        TORCH_CHECK(false, "unsupported layout ", layout);
    }
  }

  // Shared implementation behind both shallow_copy_and_detach overloads.
  // If a torch_dispatch mode is active, or this tensor has a Python-side
  // presence (DispatchKey::Python, not excluded), detaching is delegated
  // to the Python interpreter; otherwise metadata is copied C++-side.
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const {
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    c10::impl::PyInterpreter&& interpreter = nullptr;
    if (mode_stack_len > 0 &&
        !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
      // An active torch_dispatch mode wins: use its interpreter.
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      interpreter = cur_torch_dispatch_mode_state->pyinterpreter();
    } else if (
        key_set_.has(DispatchKey::Python) &&
        !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
      interpreter = pyobj_slot_.load_pyobj_interpreter();
    } else {
      // otherwise just copy the SparseTensorImpl and not the PyObject.
      auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
          key_set(), device(), layout_impl(), dtype());
      copy_tensor_metadata(
          /*src_sparse_impl=*/this,
          /*dest_sparse_impl=*/impl.get(),
          /*version_counter=*/version_counter,
          /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
      impl->refresh_numel();
      return impl;
    }
    auto r = interpreter->detach(this);
    r->set_version_counter(std::forward<VariableVersion>(version_counter));
    r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    return r;
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    return shallow_copy_and_detach_core(
        version_counter, allow_tensor_metadata_change);
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override {
    return shallow_copy_and_detach_core(
        std::move(version_counter), allow_tensor_metadata_change);
  }

 private:
  explicit SparseCsrTensorImpl(
      at::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      at::Tensor crow_indices,
      at::Tensor col_indices,
      at::Tensor values,
      at::Layout layout);

  const char* tensorimpl_type_name() const override;

  /**
   * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
   * storage_offset) from one TensorImpl to another TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
   * [ TensorImpl Shallow-Copying ].
   */
  static void copy_tensor_metadata(
      const SparseCsrTensorImpl* src_sparse_impl,
      SparseCsrTensorImpl* dest_sparse_impl,
      c10::VariableVersion version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_sparse_impl,
        dest_sparse_impl,
        std::move(version_counter),
        allow_tensor_metadata_change);

    // Sparse-specific fields
    dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices();
    dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices();
    dest_sparse_impl->values_ = src_sparse_impl->values();
    dest_sparse_impl->layout_ = src_sparse_impl->layout_impl();
  }
};
206
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/SparseCsrTensorUtils.h ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/SparseCsrTensorImpl.h>
4
+ #include <ATen/SparseTensorImpl.h>
5
+ #include <ATen/core/Tensor.h>
6
+
7
+ #ifndef AT_PER_OPERATOR_HEADERS
8
+ #include <ATen/Functions.h>
9
+ #include <ATen/NativeFunctions.h>
10
+ #include <ATen/Operators.h>
11
+ #else
12
+ #include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
13
+ #include <ATen/ops/resize_as_sparse_native.h>
14
+ #endif
15
+
16
+ #define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
17
+ [&] { \
18
+ const auto& the_layout = LAYOUT; \
19
+ switch (the_layout) { \
20
+ case kSparseCsr: \
21
+ case kSparseCsc: \
22
+ case kSparseBsr: \
23
+ case kSparseBsc: \
24
+ return __VA_ARGS__(); \
25
+ default: \
26
+ AT_ERROR( \
27
+ NAME, \
28
+ " expected sparse compressed tensor layout but got ", \
29
+ the_layout); \
30
+ } \
31
+ }()
32
+
33
+ #define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \
34
+ LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \
35
+ [&]() { \
36
+ const auto& the_layout = LAYOUT; \
37
+ switch (the_layout) { \
38
+ case kSparseCsr: \
39
+ case kSparseBsr: \
40
+ return (ROW_DIM_ACTION)(); \
41
+ case kSparseCsc: \
42
+ case kSparseBsc: \
43
+ return (COLUMN_DIM_ACTION)(); \
44
+ default: \
45
+ AT_ERROR( \
46
+ NAME, \
47
+ " expected sparse compressed tensor layout but got ", \
48
+ the_layout); \
49
+ } \
50
+ }()
51
+
52
+ #define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \
53
+ LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \
54
+ [&]() { \
55
+ const auto& the_layout = LAYOUT; \
56
+ switch (the_layout) { \
57
+ case kSparseCsr: \
58
+ case kSparseCsc: \
59
+ return (NO_BLOCK_ACTION)(); \
60
+ case kSparseBsr: \
61
+ case kSparseBsc: \
62
+ return (BLOCK_ACTION)(); \
63
+ default: \
64
+ AT_ERROR( \
65
+ NAME, \
66
+ " expected sparse compressed tensor layout but got ", \
67
+ the_layout); \
68
+ } \
69
+ }()
70
+
71
+ #define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \
72
+ LAYOUT, NAME, ROW_DIM_ACTION) \
73
+ [&]() { \
74
+ const auto& the_layout = LAYOUT; \
75
+ switch (the_layout) { \
76
+ case kSparseCsr: \
77
+ case kSparseBsr: \
78
+ return (ROW_DIM_ACTION)(); \
79
+ default: \
80
+ AT_ERROR( \
81
+ NAME, \
82
+ " expected sparse row compressed tensor layout but got ", \
83
+ the_layout); \
84
+ } \
85
+ }()
86
+
87
+ #define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \
88
+ LAYOUT, NAME, COL_DIM_ACTION) \
89
+ [&]() { \
90
+ const auto& the_layout = LAYOUT; \
91
+ switch (the_layout) { \
92
+ case kSparseCsc: \
93
+ case kSparseBsc: \
94
+ return (COL_DIM_ACTION)(); \
95
+ default: \
96
+ AT_ERROR( \
97
+ NAME, \
98
+ " expected sparse column compressed tensor layout but got ", \
99
+ the_layout); \
100
+ } \
101
+ }()
102
+
103
+ #define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
104
+ [&]() { \
105
+ const auto& the_layout = LAYOUT; \
106
+ switch (the_layout) { \
107
+ case kSparseCsr: \
108
+ case kSparseCsc: \
109
+ return (ACTION)(); \
110
+ default: \
111
+ AT_ERROR( \
112
+ NAME, \
113
+ " expected sparse compressed (non-block) tensor layout but got ", \
114
+ the_layout); \
115
+ } \
116
+ }()
117
+
118
+ #define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
119
+ [&]() { \
120
+ const auto& the_layout = LAYOUT; \
121
+ switch (the_layout) { \
122
+ case kSparseBsr: \
123
+ case kSparseBsc: \
124
+ return (ACTION)(); \
125
+ default: \
126
+ AT_ERROR( \
127
+ NAME, \
128
+ " expected sparse compressed block tensor layout but got ", \
129
+ the_layout); \
130
+ } \
131
+ }()
132
+
133
+ #define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
134
+ AT_DISPATCH_SWITCH( \
135
+ TYPE, \
136
+ NAME, \
137
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
138
+ kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
139
+
140
+ namespace at::sparse_csr {
141
+
142
+ // Implements RAII object to manage checking sparse tensor invariants:
143
+ class CheckSparseTensorInvariants {
144
+ bool old_state;
145
+
146
+ public:
147
+ CheckSparseTensorInvariants(bool state) {
148
+ old_state = at::globalContext().checkSparseTensorInvariants();
149
+ at::globalContext().setCheckSparseTensorInvariants(state);
150
+ }
151
+
152
+ ~CheckSparseTensorInvariants() {
153
+ at::globalContext().setCheckSparseTensorInvariants(old_state);
154
+ }
155
+ };
156
+
157
+ using SparseCsrTensor = Tensor;
158
+
159
+ inline bool is_sparse_compressed(const Layout& layout) {
160
+ switch (layout) {
161
+ case kSparseCsr:
162
+ case kSparseCsc:
163
+ case kSparseBsr:
164
+ case kSparseBsc:
165
+ return true;
166
+ default:;
167
+ }
168
+ return false;
169
+ }
170
+
171
+ inline bool is_sparse_compressed(const Tensor& self) {
172
+ return is_sparse_compressed(self.layout());
173
+ }
174
+
175
+ inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
176
+ AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
177
+ self.layout(), "get_sparse_csr_impl", [&] {});
178
+ return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
179
+ }
180
+
181
+ inline std::string layoutToString(
182
+ Layout layout,
183
+ bool upper = false,
184
+ bool lower = false) {
185
+ switch (layout) {
186
+ case kSparseCsr:
187
+ return (upper ? "CSR" : (lower ? "csr" : "Csr"));
188
+ case kSparseCsc:
189
+ return (upper ? "CSC" : (lower ? "csc" : "Csc"));
190
+ case kSparseBsr:
191
+ return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
192
+ case kSparseBsc:
193
+ return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
194
+ default:
195
+ TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
196
+ return "";
197
+ }
198
+ }
199
+
200
+ inline bool isCompressedRow(Layout layout) {
201
+ return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
202
+ layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
203
+ }
204
+
205
+ inline bool isCompressedColumn(Layout layout) {
206
+ return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
207
+ layout,
208
+ "isCompressedColumn",
209
+ [&] { return false; },
210
+ [&] { return true; });
211
+ }
212
+
213
+ inline std::string compressedIndicesName(Layout layout) {
214
+ return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
215
+ layout,
216
+ "compressedIndicesName",
217
+ [&] { return "crow_indices"; },
218
+ [&] { return "ccol_indices"; });
219
+ }
220
+
221
+ inline std::string plainIndicesName(Layout layout) {
222
+ return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
223
+ layout,
224
+ "plainIndicesName",
225
+ [&] { return "col_indices"; },
226
+ [&] { return "row_indices"; });
227
+ }
228
+
229
+ inline std::string compressedDimName(Layout layout) {
230
+ switch (layout) {
231
+ case kSparseCsr:
232
+ return "row";
233
+ case kSparseCsc:
234
+ return "column";
235
+ case kSparseBsr:
236
+ return "row block";
237
+ case kSparseBsc:
238
+ return "column block";
239
+ default:
240
+ TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
241
+ return "";
242
+ }
243
+ }
244
+
245
+ inline std::string plainDimName(Layout layout) {
246
+ switch (layout) {
247
+ case kSparseCsr:
248
+ return "column";
249
+ case kSparseCsc:
250
+ return "row";
251
+ case kSparseBsr:
252
+ return "column block";
253
+ case kSparseBsc:
254
+ return "row block";
255
+ default:
256
+ TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
257
+ return "";
258
+ }
259
+ }
260
+
261
+ inline size_t rowDimension(Layout layout, IntArrayRef size) {
262
+ return size.size() - (isCompressedRow(layout) ? 2 : 1);
263
+ }
264
+
265
+ inline size_t columnDimension(Layout layout, IntArrayRef size) {
266
+ return size.size() - (isCompressedColumn(layout) ? 2 : 1);
267
+ }
268
+
269
+ inline size_t compressedDimension(
270
+ Layout layout,
271
+ IntArrayRef size,
272
+ size_t dense_ndim = 0) {
273
+ return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
274
+ }
275
+
276
+ inline size_t plainDimension(
277
+ Layout layout,
278
+ IntArrayRef size,
279
+ size_t dense_ndim = 0) {
280
+ return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
281
+ }
282
+
283
+ inline int64_t numBatchDimensions(Tensor const& self) {
284
+ return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
285
+ self.layout(),
286
+ "numBatchDimensions",
287
+ [&self] { return self.crow_indices().dim() - 1; },
288
+ [&self] { return self.ccol_indices().dim() - 1; });
289
+ }
290
+
291
+ inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
292
+ return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
293
+ self.layout(),
294
+ "getCompressedPlainIndices",
295
+ [&self] {
296
+ return std::make_pair(self.crow_indices(), self.col_indices());
297
+ },
298
+ [&self] {
299
+ return std::make_pair(self.ccol_indices(), self.row_indices());
300
+ });
301
+ }
302
+
303
+ inline ScalarType getIndexDtype(Tensor const& self) {
304
+ switch (self.layout()) {
305
+ case kSparseCsr:
306
+ case kSparseBsr:
307
+ return self.crow_indices().scalar_type();
308
+ case kSparseCsc:
309
+ case kSparseBsc:
310
+ return self.ccol_indices().scalar_type();
311
+ case kSparse:
312
+ return self._indices().scalar_type();
313
+ default:
314
+ return ScalarType::Long;
315
+ }
316
+ }
317
+
318
+ inline Layout flip_compressed_layout(Layout layout) {
319
+ switch (layout) {
320
+ case kSparseCsr:
321
+ return kSparseCsc;
322
+ case kSparseCsc:
323
+ return kSparseCsr;
324
+ case kSparseBsr:
325
+ return kSparseBsc;
326
+ case kSparseBsc:
327
+ return kSparseBsr;
328
+ default:
329
+ TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
330
+ return kSparseCsr;
331
+ }
332
+ }
333
+
334
+ inline DimVector getBlockSize(Tensor const& self) {
335
+ int64_t n_batch = numBatchDimensions(self);
336
+ return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
337
+ }
338
+
339
+ inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
340
+ if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
341
+ int64_t n_batch = numBatchDimensions(self);
342
+ return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
343
+ } else {
344
+ return {};
345
+ }
346
+ }
347
+
348
+ template <typename binary_op_t, typename binary_op_out_t>
349
+ inline bool only_sparse_compressed_binary_op_trivial_cases(
350
+ const Tensor& self,
351
+ const Tensor& other,
352
+ const Scalar& alpha,
353
+ Tensor& out,
354
+ const binary_op_t& binary_op,
355
+ const binary_op_out_t& binary_op_out) {
356
+ // Only sparse compressed! Just like the name says :)
357
+ TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self));
358
+ TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other));
359
+ TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out));
360
+
361
+ // Bypass BLAS if there are matches in (self, other, out)
362
+ if (self.is_same(out) && self.is_same(other)) {
363
+ binary_op_out(self.values(), other.values(), alpha);
364
+ return true;
365
+ }
366
+ if (self.is_same(other)) {
367
+ auto [compressed_indices, plain_indices] =
368
+ at::sparse_csr::getCompressedPlainIndices(self);
369
+ static_cast<SparseCsrTensorImpl*>(out.unsafeGetTensorImpl())
370
+ ->set_member_tensors(
371
+ compressed_indices,
372
+ plain_indices,
373
+ binary_op(self.values(), other.values(), alpha),
374
+ self.sizes());
375
+ return true;
376
+ }
377
+ return false;
378
+ }
379
+
380
+ inline bool only_sparse_compressed_add_trivial_cases(
381
+ const Tensor& self,
382
+ const Tensor& other,
383
+ const Scalar& alpha,
384
+ Tensor& out) {
385
+ return only_sparse_compressed_binary_op_trivial_cases(
386
+ self,
387
+ other,
388
+ alpha,
389
+ out,
390
+ [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
391
+ return v1.add(v2, alpha);
392
+ },
393
+ [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
394
+ return v1.add_(v2, alpha);
395
+ });
396
+ }
397
+
398
+ inline Tensor to_type(const Tensor& input, ScalarType dtype) {
399
+ auto [compressed_indices, plain_indices] =
400
+ at::sparse_csr::getCompressedPlainIndices(input);
401
+ return at::_sparse_compressed_tensor_unsafe(
402
+ compressed_indices,
403
+ plain_indices,
404
+ std::move(input.values()).to(dtype),
405
+ input.sizes(),
406
+ dtype,
407
+ input.layout(),
408
+ input.device(),
409
+ input.options().pinned_memory_opt());
410
+ }
411
+
412
+ template <typename acc_t, typename scalar_t>
413
+ inline std::tuple<Tensor, Tensor> create_acc_buffer(
414
+ TensorOptions option,
415
+ ScalarType type,
416
+ int64_t nnz = -1) {
417
+ Tensor new_values, new_values_acc;
418
+ constexpr bool need_acc = !std::is_same_v<scalar_t, acc_t>;
419
+ bool is_integral = at::isIntegralType(type, /*includeBool=*/true);
420
+ if constexpr (need_acc) {
421
+ auto acc_dtype = CppTypeToScalarType<acc_t>::value;
422
+ new_values_acc = at::empty({}, option.dtype(acc_dtype));
423
+ new_values = is_integral ? new_values_acc : at::empty({}, option);
424
+ } else {
425
+ new_values = new_values_acc = at::empty({}, option);
426
+ }
427
+ if (nnz != -1) {
428
+ return std::make_tuple(
429
+ new_values.resize_(nnz), new_values_acc.resize_(nnz));
430
+ } else {
431
+ return std::make_tuple(new_values, new_values_acc);
432
+ }
433
+ }
434
+
435
+ inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) {
436
+ if (!new_values_acc.is_same(new_values)) {
437
+ new_values.copy_(new_values_acc);
438
+ }
439
+ }
440
+
441
+ } // namespace at::sparse_csr
.venv/lib/python3.11/site-packages/torch/include/ATen/Storage.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #pragma once
2
+ #include <c10/core/Storage.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/Tensor.h ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>