koichi12 commited on
Commit
f52c26c
·
verified ·
1 Parent(s): 2fabb86

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torch/include/ATen/native/AmpKernels.h +28 -0
  2. .venv/lib/python3.11/site-packages/torch/include/ATen/native/BucketizationUtils.h +173 -0
  3. .venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUBlas.h +226 -0
  4. .venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUFallback.h +46 -0
  5. .venv/lib/python3.11/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h +13 -0
  6. .venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h +34 -0
  7. .venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h +263 -0
  8. .venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvUtils.h +449 -0
  9. .venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvolutionMM3d.h +14 -0
  10. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Copy.h +20 -0
  11. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Cross.h +14 -0
  12. .venv/lib/python3.11/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h +229 -0
  13. .venv/lib/python3.11/site-packages/torch/include/ATen/native/DispatchStub.h +444 -0
  14. .venv/lib/python3.11/site-packages/torch/include/ATen/native/DistributionTemplates.h +394 -0
  15. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Distributions.h +518 -0
  16. .venv/lib/python3.11/site-packages/torch/include/ATen/native/EmbeddingBag.h +153 -0
  17. .venv/lib/python3.11/site-packages/torch/include/ATen/native/ForeachUtils.h +396 -0
  18. .venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedAdagrad.h +20 -0
  19. .venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedAdam.h +27 -0
  20. .venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedSGD.h +21 -0
  21. .venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSamplerUtils.h +105 -0
  22. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Histogram.h +16 -0
  23. .venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexKernel.h +41 -0
  24. .venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexingUtils.h +160 -0
  25. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Lerp.h +46 -0
  26. .venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebra.h +17 -0
  27. .venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h +623 -0
  28. .venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitsFallback.h +157 -0
  29. .venv/lib/python3.11/site-packages/torch/include/ATen/native/MaxPooling.h +97 -0
  30. .venv/lib/python3.11/site-packages/torch/include/ATen/native/NonSymbolicBC.h +26 -0
  31. .venv/lib/python3.11/site-packages/torch/include/ATen/native/RangeFactories.h +12 -0
  32. .venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceAllOps.h +16 -0
  33. .venv/lib/python3.11/site-packages/torch/include/ATen/native/ReductionType.h +40 -0
  34. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Repeat.h +48 -0
  35. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Resize.h +173 -0
  36. .venv/lib/python3.11/site-packages/torch/include/ATen/native/ResizeCommon.h +75 -0
  37. .venv/lib/python3.11/site-packages/torch/include/ATen/native/ScatterGatherChecks.h +128 -0
  38. .venv/lib/python3.11/site-packages/torch/include/ATen/native/SharedReduceOps.h +544 -0
  39. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Sorting.h +28 -0
  40. .venv/lib/python3.11/site-packages/torch/include/ATen/native/SortingUtils.h +88 -0
  41. .venv/lib/python3.11/site-packages/torch/include/ATen/native/SparseTensorUtils.h +190 -0
  42. .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h +49 -0
  43. .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h +94 -0
  44. .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorDimApply.h +55 -0
  45. .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorFactories.h +142 -0
  46. .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h +52 -0
  47. .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorShape.h +105 -0
  48. .venv/lib/python3.11/site-packages/torch/include/ATen/native/TriangularOpsUtils.h +57 -0
  49. .venv/lib/python3.11/site-packages/torch/include/ATen/native/UnaryOps.h +130 -0
  50. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold2d.h +48 -0
.venv/lib/python3.11/site-packages/torch/include/ATen/native/AmpKernels.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <ATen/core/ATen_fwd.h>
5
+
6
+ namespace at {
7
+ class Tensor;
8
+
9
+ namespace native {
10
+
11
+ using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
12
+ TensorList,
13
+ Tensor&,
14
+ const Tensor&);
15
+
16
+ using _amp_update_scale_cpu__fn = Tensor& (*)(
17
+ Tensor&,
18
+ Tensor&,
19
+ const Tensor&,
20
+ double,
21
+ double,
22
+ int64_t);
23
+
24
+ DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub);
25
+ DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub);
26
+
27
+ } // namespace native
28
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/BucketizationUtils.h ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/TypeProperties.h>
5
+ #include <ATen/ScalarOps.h>
6
+
7
+ #ifndef AT_PER_OPERATOR_HEADERS
8
+ #include <ATen/NativeFunctions.h>
9
+ #else
10
+ #include <ATen/ops/result_type.h>
11
+ #endif
12
+
13
+ namespace at::native {
14
+
15
+ // original values given by raw_*. If an original value is not contiguous, will make a contiguous copy to
16
+ // the corresponding trimmed_* value. Additionally, if the dtypes of the boundary and input tensor do not
17
+ // match, will change them to be a common super type so comparisons are done between the same types.
18
+ // For any trimmed_* tensor, if its outgoing value matches what it was incoming (typically null), then the
19
+ // corresponding raw_* version should be used since it was already contiguous of the right type.
20
+ inline void searchsorted_maybe_trim_input_tensors(
21
+ Tensor& trimmed_input,
22
+ Tensor& trimmed_boundaries,
23
+ Tensor& trimmed_sorter,
24
+ const Tensor& raw_input,
25
+ const Tensor& raw_boundaries,
26
+ const Tensor& raw_sorter) {
27
+ bool in_is_contiguous = raw_input.is_contiguous();
28
+ bool bd_is_contiguous = raw_boundaries.is_contiguous();
29
+ bool sort_is_contiguous = raw_sorter.is_contiguous();
30
+
31
+ if (!in_is_contiguous) {
32
+ TORCH_WARN_ONCE("torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due "
33
+ "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value "
34
+ "tensor if possible. This message will only appear once per program.");
35
+ trimmed_input = raw_input.contiguous();
36
+ }
37
+ if (!bd_is_contiguous) {
38
+ TORCH_WARN_ONCE("torch.searchsorted(): boundary tensor is non-contiguous, this will lower the performance due "
39
+ "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous boundary "
40
+ "tensor if possible. This message will only appear once per program.");
41
+ trimmed_boundaries = raw_boundaries.contiguous();
42
+ }
43
+ if (!sort_is_contiguous) {
44
+ TORCH_WARN_ONCE("torch.searchsorted(): sorter tensor is non-contiguous, this will lower the performance due "
45
+ "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous sorter "
46
+ "tensor if possible. This message will only appear once per program.");
47
+ trimmed_sorter = raw_sorter.contiguous();
48
+ }
49
+ if (raw_input.dtype() != raw_boundaries.dtype()) {
50
+ at::native::ResultTypeState state = {};
51
+ state = at::native::update_result_type_state(raw_boundaries, state);
52
+ state = at::native::update_result_type_state(raw_input, state);
53
+ ScalarType common_stype = at::native::result_type(state);
54
+
55
+ TORCH_INTERNAL_ASSERT(common_stype != ScalarType::Undefined);
56
+ if (common_stype != raw_input.scalar_type()) {
57
+ trimmed_input = in_is_contiguous ? raw_input.to(common_stype) : trimmed_input.to(common_stype);
58
+ }
59
+ if (common_stype != raw_boundaries.scalar_type()) {
60
+ trimmed_boundaries = bd_is_contiguous ? raw_boundaries.to(common_stype) : trimmed_boundaries.to(common_stype);
61
+ }
62
+ }
63
+ }
64
+
65
+ /* unused but needed for internal jagged tensor class */
66
+ inline void searchsorted_maybe_trim_input_tensors(
67
+ Tensor& trimmed_input,
68
+ Tensor& trimmed_boundaries,
69
+ const Tensor& raw_input,
70
+ const Tensor& raw_boundaries) {
71
+ Tensor trimmed_sorter;
72
+ Tensor raw_sorter;
73
+ return searchsorted_maybe_trim_input_tensors(
74
+ trimmed_input,
75
+ trimmed_boundaries,
76
+ trimmed_sorter,
77
+ raw_input,
78
+ raw_boundaries,
79
+ raw_sorter);
80
+ }
81
+
82
+ inline bool searchsorted_dims_matched_before_last_dim(const Tensor& boundaries, const Tensor& input) {
83
+ if (boundaries.dim() != input.dim()) {
84
+ return false;
85
+ }
86
+ const auto& dims_bd = boundaries.sizes();
87
+ const auto& dims_in = input.sizes();
88
+ for (int64_t dim = 0; dim + 1 < boundaries.dim(); ++dim) {
89
+ if (dims_bd[dim] != dims_in[dim]) {
90
+ return false;
91
+ }
92
+ }
93
+ return true;
94
+ }
95
+
96
+ inline Tensor searchsorted_scalar_tensor(const Scalar& scalar, const c10::Device& device) {
97
+ auto tensor = c10::scalar_to_tensor(scalar, device);
98
+ // This is to adopt the scalar promotion rules defined in native/TypeProperties.h
99
+ // So we have the same type promotion rules as binary operations.
100
+ tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
101
+ return tensor;
102
+ }
103
+
104
+ inline void searchsorted_pre_check(
105
+ const Tensor& boundaries,
106
+ const Tensor& input,
107
+ const Tensor& output,
108
+ const bool out_int32,
109
+ const bool right,
110
+ const std::optional<c10::string_view> side_opt,
111
+ const Tensor& sorter) {
112
+ if (side_opt) {
113
+ const c10::string_view side = *side_opt;
114
+ TORCH_CHECK(side == "left" || side == "right", "torch.searchsorted(): side can only be 'left' or 'right' but ",
115
+ "got ", side);
116
+
117
+ // assume the user has not explicitly set (right=False, side="right")
118
+ TORCH_CHECK(!right || side == "right", "torch.searchsorted(): side and right can't be set to opposites, got side "
119
+ "of ", side, " while right was True");
120
+ }
121
+
122
+ TORCH_CHECK(boundaries.device() == input.device(), "torch.searchsorted(): boundaries and input value tensors ",
123
+ "should have same device type, but got boundaries tensor device type ", boundaries.device(), " and input value ",
124
+ "tensor device type ", input.device());
125
+
126
+ if (sorter.defined()) {
127
+ TORCH_CHECK(sorter.device() == boundaries.device(), "torch.searchsorted(): sorter and boundary tensors should ",
128
+ "have same device type, but got sorter tensor device type ", sorter.device(), " and input value tensor ",
129
+ "device type ", boundaries.device());
130
+
131
+ TORCH_CHECK(sorter.sizes() == boundaries.sizes(), "torch.searchsorted(): boundary and sorter must have the same "
132
+ "size, but got boundary tensor ", boundaries.sizes(), "and got sorter tensor ", sorter.sizes());
133
+
134
+ TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ",
135
+ "dtype but got dtype ", sorter.scalar_type());
136
+
137
+ if (sorter.numel() > 0) {
138
+ auto minmax = sorter.aminmax();
139
+ int64_t vmin = std::get<0>(minmax).item().toLong();
140
+ int64_t vmax = std::get<1>(minmax).item().toLong();
141
+ TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range");
142
+ }
143
+ }
144
+
145
+ TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),
146
+ "torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we got ",
147
+ "boundaries tensor dim(", boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(",
148
+ input.numel(), ")");
149
+
150
+ TORCH_CHECK(boundaries.dim() != 0, "torch.searchsorted(): boundaries tensor should have positive dimension, but ",
151
+ "got 0 dimension");
152
+
153
+ TORCH_CHECK(boundaries.dim() == 1 || searchsorted_dims_matched_before_last_dim(boundaries, input),
154
+ "torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries tensor ",
155
+ "and input value tensor must match, but we got boundaries tensor ", boundaries.sizes(), " and input value tensor ",
156
+ input.sizes());
157
+
158
+ ScalarType output_dtype = output.scalar_type();
159
+ TORCH_CHECK(
160
+ (output_dtype == ScalarType::Long && !out_int32) ||
161
+ (output_dtype == ScalarType::Int && out_int32),
162
+ "torch.searchsorted(): output tensor's dtype is wrong, it can only be Int(int32) or Long(int64) depending on ",
163
+ "whether out_int32 flag is True, but we got output tensor's dtype ", output_dtype,
164
+ " and out_int32 flag is ", (out_int32 ? "True" : "False"));
165
+
166
+ if (out_int32) {
167
+ TORCH_CHECK(boundaries.sizes().back() < INT_MAX,
168
+ "torch.searchsorted(): the size of boundaries' last dimension should be less than ", INT_MAX, ", but we got ",
169
+ boundaries.sizes().back());
170
+ }
171
+ }
172
+
173
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUBlas.h ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/OpMathType.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <ATen/native/TransposeType.h>
6
+ #include <c10/util/complex.h>
7
+ #include <c10/core/ScalarType.h>
8
+ #include <c10/core/Scalar.h>
9
+
10
+
11
+ namespace at::native::cpublas {
12
+
13
+ namespace internal {
14
+ void normalize_last_dims(
15
+ TransposeType transa, TransposeType transb,
16
+ int64_t m, int64_t n, int64_t k,
17
+ int64_t *lda, int64_t *ldb, int64_t *ldc);
18
+ } // namespace internal
19
+
20
+ using gemm_fn = void(*)(
21
+ at::ScalarType type,
22
+ TransposeType transa, TransposeType transb,
23
+ int64_t m, int64_t n, int64_t k,
24
+ const Scalar& alpha,
25
+ const void *a, int64_t lda,
26
+ const void *b, int64_t ldb,
27
+ const Scalar& beta,
28
+ void *c, int64_t ldc);
29
+
30
+ DECLARE_DISPATCH(gemm_fn, gemm_stub);
31
+
32
+ template <typename scalar_t>
33
+ void gemm(
34
+ TransposeType transa, TransposeType transb,
35
+ int64_t m, int64_t n, int64_t k,
36
+ at::opmath_type<scalar_t> alpha,
37
+ const scalar_t *a, int64_t lda,
38
+ const scalar_t *b, int64_t ldb,
39
+ at::opmath_type<scalar_t> beta,
40
+ scalar_t *c, int64_t ldc) {
41
+ internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
42
+ gemm_stub(
43
+ kCPU, c10::CppTypeToScalarType<scalar_t>::value,
44
+ transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
45
+ }
46
+
47
+ void gemm(
48
+ TransposeType transa, TransposeType transb,
49
+ int64_t m, int64_t n, int64_t k,
50
+ double alpha,
51
+ const double *a, int64_t lda,
52
+ const double *b, int64_t ldb,
53
+ double beta,
54
+ double *c, int64_t ldc);
55
+
56
+ void gemm(
57
+ TransposeType transa, TransposeType transb,
58
+ int64_t m, int64_t n, int64_t k,
59
+ float alpha,
60
+ const float *a, int64_t lda,
61
+ const float *b, int64_t ldb,
62
+ float beta,
63
+ float *c, int64_t ldc);
64
+
65
+ void gemm(
66
+ TransposeType transa, TransposeType transb,
67
+ int64_t m, int64_t n, int64_t k,
68
+ float alpha,
69
+ const at::BFloat16 *a, int64_t lda,
70
+ const at::BFloat16 *b, int64_t ldb,
71
+ float beta,
72
+ at::BFloat16 *c, int64_t ldc);
73
+
74
+ void gemm(
75
+ TransposeType transa, TransposeType transb,
76
+ int64_t m, int64_t n, int64_t k,
77
+ const float alpha,
78
+ const at::BFloat16 *a, int64_t lda,
79
+ const at::BFloat16 *b, int64_t ldb,
80
+ const float beta,
81
+ float *c, int64_t ldc);
82
+
83
+ void gemm(
84
+ TransposeType transa, TransposeType transb,
85
+ int64_t m, int64_t n, int64_t k,
86
+ float alpha,
87
+ const at::Half *a, int64_t lda,
88
+ const at::Half *b, int64_t ldb,
89
+ float beta,
90
+ at::Half *c, int64_t ldc);
91
+
92
+ void gemm(
93
+ TransposeType transa, TransposeType transb,
94
+ int64_t m, int64_t n, int64_t k,
95
+ const float alpha,
96
+ const at::Half *a, int64_t lda,
97
+ const at::Half *b, int64_t ldb,
98
+ const float beta,
99
+ float *c, int64_t ldc);
100
+
101
+ void gemm(
102
+ TransposeType transa, TransposeType transb,
103
+ int64_t m, int64_t n, int64_t k,
104
+ c10::complex<double> alpha,
105
+ const c10::complex<double> *a, int64_t lda,
106
+ const c10::complex<double> *b, int64_t ldb,
107
+ c10::complex<double> beta,
108
+ c10::complex<double> *c, int64_t ldc);
109
+
110
+ void gemm(
111
+ TransposeType transa, TransposeType transb,
112
+ int64_t m, int64_t n, int64_t k,
113
+ c10::complex<float> alpha,
114
+ const c10::complex<float> *a, int64_t lda,
115
+ const c10::complex<float> *b, int64_t ldb,
116
+ c10::complex<float> beta,
117
+ c10::complex<float> *c, int64_t ldc);
118
+
119
+ void gemm(
120
+ TransposeType transa, TransposeType transb,
121
+ int64_t m, int64_t n, int64_t k,
122
+ int64_t alpha,
123
+ const int64_t *a, int64_t lda,
124
+ const int64_t *b, int64_t ldb,
125
+ int64_t beta,
126
+ int64_t *c, int64_t ldc);
127
+
128
+ template <typename scalar_t>
129
+ void gemm_batched(
130
+ TransposeType transa, TransposeType transb,
131
+ int64_t batch_size, int64_t m, int64_t n, int64_t k,
132
+ scalar_t alpha,
133
+ const scalar_t * const *a, int64_t lda,
134
+ const scalar_t * const *b, int64_t ldb,
135
+ const scalar_t beta,
136
+ scalar_t * const *c, int64_t ldc);
137
+
138
+ template <typename scalar_t>
139
+ void gemm_batched_with_stride(
140
+ TransposeType transa, TransposeType transb,
141
+ int64_t batch_size, int64_t m, int64_t n, int64_t k,
142
+ scalar_t alpha,
143
+ const scalar_t *a, int64_t lda, int64_t batch_stride_a,
144
+ const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
145
+ scalar_t beta,
146
+ scalar_t *c, int64_t ldc, int64_t batch_stride_c);
147
+
148
+ using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
149
+
150
+ DECLARE_DISPATCH(axpy_fn, axpy_stub);
151
+
152
+ template<typename scalar_t>
153
+ void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
154
+ if(n == 1)
155
+ {
156
+ incx = 1;
157
+ incy = 1;
158
+ }
159
+ axpy_stub(
160
+ kCPU, c10::CppTypeToScalarType<scalar_t>::value,
161
+ n, a, x, incx, y, incy);
162
+ }
163
+
164
+ void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
165
+ void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
166
+ void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
167
+ void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
168
+
169
+ using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
170
+
171
+ DECLARE_DISPATCH(copy_fn, copy_stub);
172
+
173
+ template<typename scalar_t>
174
+ void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
175
+ if(n == 1)
176
+ {
177
+ incx = 1;
178
+ incy = 1;
179
+ }
180
+ copy_stub(
181
+ kCPU, c10::CppTypeToScalarType<scalar_t>::value,
182
+ n, x, incx, y, incy);
183
+ }
184
+
185
+ void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
186
+ void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
187
+ void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
188
+ void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
189
+
190
+ // Batch-reduce GEMM
191
+ // Operates by the following formula:
192
+ // C = alpha * SUM(A[i] x B[i]) + beta * C, i = 0 to batch size
193
+ // A Base pointer to a tensor A.
194
+ // B Base pointer to a tensor B.
195
+ // C Pointer to a tensor C (accumulation buffer).
196
+ TORCH_API void brgemm(
197
+ int64_t M,
198
+ int64_t N,
199
+ int64_t K,
200
+ int64_t ld_a,
201
+ int64_t ld_b,
202
+ int64_t ld_c,
203
+ const float alpha,
204
+ const float beta,
205
+ const at::Half* A,
206
+ const at::Half* B,
207
+ float* C);
208
+
209
+ // Release brgemm hardware context
210
+ void brgemm_release();
211
+
212
+ // Pack B matrix to get better performance if needed
213
+ void pack(
214
+ int64_t K,
215
+ int64_t N,
216
+ int64_t ld_in,
217
+ int64_t ld_out,
218
+ ScalarType dt_in,
219
+ ScalarType dt_out,
220
+ const void* in,
221
+ void* out);
222
+
223
+ // Whether pack is needed in the platform.
224
+ bool need_pack(ScalarType dt_in);
225
+
226
+ } // namespace at::native::cpublas
.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUFallback.h ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/ivalue.h>
4
+ #include <ATen/core/stack.h>
5
+ #include <ATen/core/boxing/KernelFunction.h>
6
+ #include <ATen/core/dispatch/Dispatcher.h>
7
+ #include <c10/util/Metaprogramming.h>
8
+ #include <torch/library.h>
9
+
10
+ namespace at::native {
11
+
12
+ // This function implements a boxed fallback to CPU.
13
+ // External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
14
+ TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false,
15
+ c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU);
16
+
17
+ // This is a helper function that backends can use to directly call their boxed CPU fallback
18
+ // TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
19
+ template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
20
+ struct _call_fallback_fn final {};
21
+
22
+ template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
23
+ struct _call_fallback_fn<fallback_fn, Op, symint, ReturnType(ParameterTypes...)> final {
24
+ static ReturnType call(typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
25
+ auto op = c10::Dispatcher::singleton()
26
+ // TODO: figure out how to make compiler happy without dynamic casts
27
+ .findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name)
28
+ //.findSchemaOrThrow("a", "b")
29
+ .typed<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>();
30
+ return c10::impl::BoxedKernelWrapper<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>::call(
31
+ c10::BoxedKernel::makeFromFunction<fallback_fn>(),
32
+ op,
33
+ c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
34
+ // TODO: get std::forward<> to work
35
+ args...
36
+ );
37
+ }
38
+ };
39
+
40
+ template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
41
+ using call_fallback_fn_symint = _call_fallback_fn<fallback_fn, Op, true, typename Op::schema>;
42
+
43
+ template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
44
+ using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, false, typename Op::schema>;
45
+
46
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/macros/Export.h>
3
+ #include <limits>
4
+
5
+ namespace at {
6
+ class TensorBase;
7
+ }
8
+
9
+ namespace at::native {
10
+
11
+ TORCH_API bool canUse32BitIndexMath(const at::TensorBase &t, int64_t max_elem=std::numeric_limits<int32_t>::max());
12
+
13
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/CompositeRandomAccessorCommon.h>
4
+
5
+ namespace at::native {
6
+
7
+ struct TupleInfoCPU {
8
+ template <typename ...Types>
9
+ using tuple = std::tuple<Types...>;
10
+
11
+ template <typename ...Types>
12
+ static constexpr auto tie(Types&... args) noexcept {
13
+ return std::tie(args...);
14
+ }
15
+ };
16
+
17
+ template <typename KeyAccessor, typename ValueAccessor>
18
+ using CompositeRandomAccessorCPU =
19
+ CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfoCPU>;
20
+
21
+ template <typename Values, typename References>
22
+ void swap(
23
+ references_holder<Values, References> rh1,
24
+ references_holder<Values, References> rh2
25
+ ) {
26
+ return std::swap(rh1.data(), rh2.data());
27
+ }
28
+
29
+ template <int N, typename Values, typename References>
30
+ auto get(references_holder<Values, References> rh) -> decltype(std::get<N>(rh.data())) {
31
+ return std::get<N>(rh.data());
32
+ }
33
+
34
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <utility>
2
+
3
+ #pragma once
4
+
5
+ namespace at::native {
6
+
7
+ namespace {
8
+
9
+ // operator_brackets_proxy is used in
10
+ // CompositeRandomAccessor in place of operator[].
11
+ // For some iterators, references returned by operator[]
12
+ // could become invalid, operator_brackets_proxy tries to
13
+ // resolve that by making accessor[n] to be equivalent to
14
+ // *(accessor + n).
15
+ template <typename Accessor>
16
+ class operator_brackets_proxy {
17
+ using reference = typename std::iterator_traits<Accessor>::reference;
18
+ using value_type = typename std::iterator_traits<Accessor>::value_type;
19
+
20
+ public:
21
+ C10_HOST_DEVICE
22
+ operator_brackets_proxy(Accessor const& accessor)
23
+ : accessor(accessor)
24
+ {}
25
+
26
+ C10_HOST_DEVICE
27
+ operator reference() {
28
+ return *accessor;
29
+ }
30
+
31
+ C10_HOST_DEVICE
32
+ reference operator*() {
33
+ return *accessor;
34
+ }
35
+
36
+ C10_HOST_DEVICE
37
+ operator_brackets_proxy& operator=(value_type const& val) {
38
+ *accessor = val;
39
+ return *this;
40
+ }
41
+
42
+ private:
43
+ Accessor accessor;
44
+ };
45
+
46
+ }
47
+
48
+ // references_holder is used as a surrogate for the
49
+ // references type from std::iterator_traits in CompositeRandomAccessor.
50
+ // It is assumed in CompositeRandomAccessor that
51
+ // References = tuple<Types&...>,
52
+ // Values = tuple<Types...> by default,
53
+ // but they could be anything as long as References could be
54
+ // cast to Values.
55
+ // If you plan to use it with STL, for example, you will need to
56
+ // define 'swap` and `get`(aka std::get) methods.
57
+ template <typename Values, typename References>
58
+ class references_holder {
59
+ public:
60
+ using values = Values;
61
+ using references = References;
62
+
63
+ C10_HOST_DEVICE
64
+ references_holder(references refs)
65
+ : refs{std::move(refs)}
66
+ {}
67
+
68
+ C10_HOST_DEVICE
69
+ operator references() {
70
+ return refs;
71
+ }
72
+
73
+ C10_HOST_DEVICE
74
+ operator values() {
75
+ return refs;
76
+ }
77
+
78
+ C10_HOST_DEVICE
79
+ references_holder& operator=(values vals) {
80
+ refs = vals;
81
+ return *this;
82
+ }
83
+
84
+ C10_HOST_DEVICE
85
+ references& data() {
86
+ return refs;
87
+ }
88
+
89
+ protected:
90
+ references refs;
91
+ };
92
+
93
+ // CompositeRandomAccessor is essentially a simplified version of
94
+ // a random access iterator over two random access iterators.
95
+ // TupleInfo should contain a variadic type `tuple`, and a method `tie`,
96
+ // which constructs a tuple of references from a variadic list of arguments.
97
+ template <typename KeyAccessor, typename ValueAccessor, typename TupleInfo>
98
+ class CompositeRandomAccessor {
99
+ using self_type = CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfo>;
100
+
101
+ using key_accessor_value_type =
102
+ typename std::iterator_traits<KeyAccessor>::value_type;
103
+ using value_accessor_value_type =
104
+ typename std::iterator_traits<ValueAccessor>::value_type;
105
+ using key_accessor_reference_type =
106
+ typename std::iterator_traits<KeyAccessor>::reference;
107
+ using value_accessor_reference_type =
108
+ typename std::iterator_traits<ValueAccessor>::reference;
109
+
110
+ using composite_value_type = typename TupleInfo::template tuple<
111
+ key_accessor_value_type,
112
+ value_accessor_value_type>;
113
+ using composite_reference = typename TupleInfo::template tuple<
114
+ key_accessor_reference_type,
115
+ value_accessor_reference_type>;
116
+
117
+ public:
118
+ using value_type = composite_value_type;
119
+ using reference = references_holder<composite_value_type, composite_reference>;
120
+ // Note that CompositeRandomAccessor does not hold key and values
121
+ // in a specific datastructure, which means that a pointer to a (key, value)
122
+ // is not defined. Hence we just use a pointer type of the KeyAccessor.
123
+ using pointer = typename std::iterator_traits<KeyAccessor>::pointer;
124
+ using difference_type = typename std::iterator_traits<KeyAccessor>::difference_type;
125
+ using iterator_category = std::random_access_iterator_tag;
126
+
127
+ C10_HOST_DEVICE
128
+ CompositeRandomAccessor() = default;
129
+
130
+ C10_HOST_DEVICE
131
+ CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values)
132
+ : keys(keys), values(values)
133
+ {}
134
+
135
+ // Pointer-like operations {
136
+ C10_HOST_DEVICE
137
+ reference operator*() const {
138
+ return TupleInfo::tie(*keys, *values);
139
+ }
140
+
141
+ // operator->() is supposed to return a pointer type.
142
+ // Since CompositeRandomAccessor does not hold pointers to pairs,
143
+ // we just return a pointer to a key.
144
+ C10_HOST_DEVICE
145
+ auto* operator->() const {
146
+ return keys.operator->();
147
+ }
148
+
149
+ C10_HOST_DEVICE
150
+ reference operator[](difference_type idx) {
151
+ return operator_brackets_proxy<self_type>(
152
+ CompositeRandomAccessor(keys + idx, values + idx)
153
+ );
154
+ }
155
+ // }
156
+
157
+ // Prefix/postfix increment/decrement {
158
+ C10_HOST_DEVICE
159
+ CompositeRandomAccessor& operator++() {
160
+ ++keys;
161
+ ++values;
162
+ return *this;
163
+ }
164
+
165
+ C10_HOST_DEVICE
166
+ CompositeRandomAccessor operator++(int) {
167
+ CompositeRandomAccessor copy(*this);
168
+ ++*this;
169
+ return copy;
170
+ }
171
+
172
+ C10_HOST_DEVICE
173
+ CompositeRandomAccessor& operator--() {
174
+ --keys;
175
+ --values;
176
+ return *this;
177
+ }
178
+
179
+ C10_HOST_DEVICE
180
+ CompositeRandomAccessor operator--(int) {
181
+ CompositeRandomAccessor copy(*this);
182
+ --*this;
183
+ return copy;
184
+ }
185
+ // }
186
+
187
+ // Arithmetic operations {
188
+ C10_HOST_DEVICE
189
+ CompositeRandomAccessor& operator+=(difference_type offset) {
190
+ keys += offset;
191
+ values += offset;
192
+ return *this;
193
+ }
194
+
195
+ C10_HOST_DEVICE
196
+ CompositeRandomAccessor operator+(difference_type offset) const {
197
+ return CompositeRandomAccessor(keys + offset, values + offset);
198
+ }
199
+
200
+ C10_HOST_DEVICE
201
+ friend CompositeRandomAccessor operator+(
202
+ difference_type offset,
203
+ const CompositeRandomAccessor& accessor
204
+ ) {
205
+ return accessor + offset;
206
+ }
207
+
208
+ C10_HOST_DEVICE
209
+ CompositeRandomAccessor& operator-=(difference_type offset) {
210
+ keys -= offset;
211
+ values -= offset;
212
+ return *this;
213
+ }
214
+
215
+ C10_HOST_DEVICE
216
+ CompositeRandomAccessor operator-(difference_type offset) const {
217
+ return CompositeRandomAccessor(keys - offset, values - offset);
218
+ }
219
+
220
+ C10_HOST_DEVICE
221
+ difference_type operator-(const CompositeRandomAccessor& other) const {
222
+ return keys - other.keys;
223
+ }
224
+ // }
225
+
226
+ // Comparison operators {
227
+ C10_HOST_DEVICE
228
+ bool operator==(const CompositeRandomAccessor& other) const {
229
+ return keys == other.keys;
230
+ }
231
+
232
+ C10_HOST_DEVICE
233
+ bool operator!=(const CompositeRandomAccessor& other) const {
234
+ return keys != other.keys;
235
+ }
236
+
237
+ C10_HOST_DEVICE
238
+ bool operator<(const CompositeRandomAccessor& other) const {
239
+ return keys < other.keys;
240
+ }
241
+
242
+ C10_HOST_DEVICE
243
+ bool operator<=(const CompositeRandomAccessor& other) const {
244
+ return keys <= other.keys;
245
+ }
246
+
247
+ C10_HOST_DEVICE
248
+ bool operator>(const CompositeRandomAccessor& other) const {
249
+ return keys > other.keys;
250
+ }
251
+
252
+ C10_HOST_DEVICE
253
+ bool operator>=(const CompositeRandomAccessor& other) const {
254
+ return keys >= other.keys;
255
+ }
256
+ // }
257
+
258
+ protected:
259
+ KeyAccessor keys;
260
+ ValueAccessor values;
261
+ };
262
+
263
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvUtils.h ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <ATen/TensorUtils.h>
4
+ #include <ATen/detail/CUDAHooksInterface.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+ #include <c10/util/env.h>
7
+ #include <c10/util/irange.h>
8
+
9
+ #include <utility>
10
+
11
+ namespace at::native {
12
+
13
// ---------------------------------------------------------------------
// Dispatch stubs for the per-backend convolution backward (and transpose)
// kernels.  Each `*_fn` alias is a kernel signature; the matching
// DECLARE_DISPATCH declares the stub that device backends register into.
// The trailing std::array<bool, N> parameter is presumably the output mask
// selecting which gradients to compute — confirm at the kernel definitions.
// ---------------------------------------------------------------------
using conv_depthwise2d_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 2>);
DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub);
using conv_depthwise3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub);
using cudnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub);
using mps_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub);
using cudnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub);
using miopen_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub);
using miopen_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub);
using miopen_depthwise_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub);
using mkldnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub);
// Forward (not backward) transposed convolution for the oneDNN backend.
using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const std::optional<Tensor>&,
    IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t);
DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub);
using mkldnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub);
using slow_conv_dilated2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub);
using slow_conv_dilated3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub);
using slow_conv_transpose2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub);
using slow_conv_transpose3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub);
72
+
73
namespace {
// Samples TORCH_CUDNN_USE_HEURISTIC_MODE_B once and caches the result for
// the lifetime of the process (later changes to the env var have no effect).
// NOTE(review): an anonymous namespace in a header gives each translation
// unit its own copy of this function and its cached flag — confirm this
// per-TU duplication is intentional.
bool is_cudnnv8_heuristic_mode_b() {
  static const bool is_cudnnv8_heuristic_mode_b = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true;
  return is_cudnnv8_heuristic_mode_b;
}
}
79
+
80
+ inline bool cudnnv8_enabled_check_debug() {
81
+ static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
82
+ static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
83
+ static uint8_t cudnnv8_debugcount = 0;
84
+ if (cudnnv8_debug == 1 && cudnnv8_debugcount < 10) {
85
+ TORCH_WARN("TORCH_CUDNN_V8_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE B: ", is_cudnnv8_heuristic_mode_b());
86
+ cudnnv8_debugcount++;
87
+ }
88
+ return cudnnv8_flag == 1;
89
+ }
90
+
91
// Public wrapper over the cached TORCH_CUDNN_USE_HEURISTIC_MODE_B flag.
inline bool cudnnv8_use_heur_mode_b() {
  return is_cudnnv8_heuristic_mode_b();
}
94
+
95
// Keep in sync with py::enum_ in Module.cpp
// Identifies which convolution implementation the dispatcher selected; see
// select_conv_backend below.  Entries cover device-specific libraries
// (cuDNN, MIOpen, oneDNN/mkldnn, MPS, XNNPACK/NNPACK) and the slow
// reference kernels.
enum class ConvBackend {
  CudaDepthwise2d,
  CudaDepthwise3d,
  Cudnn,
  CudnnTranspose,
  Empty,
  Miopen,
  MiopenDepthwise,
  MiopenTranspose,
  Mkldnn,
  MkldnnTranspose,
  MkldnnEmpty,
  NnpackSpatial,
  Overrideable,
  Slow2d,
  Slow3d,
  SlowDilated2d,
  SlowDilated3d,
  SlowTranspose2d,
  SlowTranspose3d,
  Winograd3x3Depthwise,
  Xnnpack2d,
  Mps,
  MpsTranspose,
};
121
+
122
// Overload for selecting the convolution backend from the full set of convolution inputs.
// This overload is exposed to python for testing, etc.
TORCH_API ConvBackend select_conv_backend(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation,
    bool transposed, SymIntArrayRef output_padding, c10::SymInt groups, const at::OptionalSymIntArrayRef bias_sizes_opt);

// Picks the memory format (contiguous vs channels-last) that the already
// chosen backend should run with for this input/weight pair.
TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input,
    const Tensor& weight,
    const ConvBackend backend);
132
+
133
+ // ---------------------------------------------------------------------
134
+ //
135
+ // Math
136
+ //
137
+ // ---------------------------------------------------------------------
138
+
139
// Canonical dimension indices for conv tensors in NCHW-style layouts.
constexpr int input_batch_size_dim = 0;  // also grad_input
constexpr int input_channels_dim = 1;
constexpr int output_batch_size_dim = 0;  // also grad_output
constexpr int output_channels_dim = 1;
constexpr int weight_output_channels_dim = 0;
constexpr int weight_input_channels_dim = 1;

// Often written as 2 + max_dim (extra dims for batch size and channels)
constexpr int max_dim = 3;
148
+
149
+ // ---------------------------------------------------------------------
150
+ //
151
+ // Checking
152
+ //
153
+ // ---------------------------------------------------------------------
154
+
155
+ // Used on pad, stride and dilation
156
+ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
157
+ {
158
+ TORCH_CHECK(args.size() <= expected_size,
159
+ "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
160
+ expected_size, " (while checking arguments for ", c, ")");
161
+ TORCH_CHECK(args.size() >= expected_size,
162
+ "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
163
+ expected_size, " (while checking arguments for ", c, ")");
164
+
165
+ auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;});
166
+ if (num_negative_values > 0){
167
+ std::stringstream ss;
168
+ ss << arg_name << " should be greater than zero but got (";
169
+ std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
170
+ ss << args.back() << ")" << " (while checking arguments for " << c << ")";
171
+ AT_ERROR(ss.str());
172
+ }
173
+ }
174
+
175
+
176
+ // NOTE [ Convolution checks ]
177
+ //
178
+ // NB: For many call sites, it is not strictly necessary to check all of
179
+ // these relationships (for example, for forward convolution, we compute
180
+ // the size of output ourselves, so we don't actually need to check
181
+ // output. However, writing a single function that does everything
182
+ // means we get to reuse it for both forwards and all backwards
183
+ // variants, even when the set of "real" inputs varies. The magic of
184
+ // relational computing!
185
+ //
186
+ // (There is one downside, which is that it is slightly harder to write
187
+ // error messages which are able to distinguish between real inputs
188
+ // (which the user can change) and computed inputs (which the user can
189
+ // only indirectly affect). It would be an interesting exercise to
190
+ // come up with a general framework to handle such situations.)
191
// Relational shape check shared by conv forward and all backward variants:
// validates padding/stride/dilation lengths and the input/weight/output
// dimensionality and channel relationships.
inline void convolution_shape_check(
    CheckedFrom c,
    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
  // `padding` fixes the number of spatial dims; stride/dilation must match it.
  check_args(c, padding, input->dim() - 2, "padding");
  check_args(c, stride, padding.size(), "stride");
  check_args(c, dilation, padding.size(), "dilation");

  // Input
  checkDimRange(c, input, 3, 6 /* exclusive */);
  // input channels must equal weight's in-channels-per-group times groups
  checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups);

  // Weight
  checkSameDim(c, input, weight);

  // TODO: check that output->size() matches output_sizes
  // TODO: check that weight matches output->sizes()
  checkSameDim(c, input, output);
}
211
+
212
+ // NB: conv_output_size and conv_input_size are not bijections,
213
+ // as conv_output_size loses information; this is why conv_input_size
214
+ // takes an extra output_padding argument to resolve the ambiguity.
215
+
216
// Shared implementation for conv_output_size (T is int64_t or c10::SymInt).
// Per spatial dim d: out = (in + 2*pad - (dilation*(k-1) + 1)) / stride + 1.
// An empty `dilation` is treated as all ones.
template <typename T>
inline std::vector<T> _conv_output_size(
    ArrayRef<T> input_size, ArrayRef<T> weight_size,
    ArrayRef<T> padding, ArrayRef<T> stride, ArrayRef<T> dilation = ArrayRef<T>()
) {
  // ASSERT(input_size.size() > 2)
  // ASSERT(input_size.size() == weight_size.size())
  bool has_dilation = !dilation.empty();
  auto dim = input_size.size();
  std::vector<T> output_size(dim);
  output_size[0] = input_size[input_batch_size_dim];
  output_size[1] = weight_size[weight_output_channels_dim];
  for (const auto d : c10::irange(2, dim)) {
    auto dilation_ = has_dilation ? dilation[d - 2] : 1;
    // effective (dilated) kernel extent along this axis
    auto kernel = dilation_ * (weight_size[d] - 1) + 1;
    output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
  }
  return output_size;
}
235
+
236
// Concrete-int forwarder to the shared template implementation.
inline std::vector<int64_t> conv_output_size(
    IntArrayRef input_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
) {
  return _conv_output_size(input_size, weight_size, padding, stride, dilation);
}
242
+
243
// Symbolic-int forwarder to the shared template implementation.
inline std::vector<c10::SymInt> conv_output_size(
    SymIntArrayRef input_size, SymIntArrayRef weight_size,
    SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()
) {
  return _conv_output_size(input_size, weight_size, padding, stride, dilation);
}
249
+
250
// Inverse of _conv_output_size: recovers the input size of a convolution
// from its output size.  `output_padding` resolves the rounding ambiguity
// (conv_output_size is lossy; see the NB comment above).
template <typename T>
std::vector<T> _conv_input_size(
    ArrayRef<T> output_size, ArrayRef<T> weight_size,
    ArrayRef<T> padding, ArrayRef<T> output_padding, ArrayRef<T> stride, ArrayRef<T> dilation, T groups
) {
  // ASSERT(output_size.size() > 2)
  // ASSERT(output_size.size() == weight_size.size())
  auto dim = output_size.size();
  std::vector<T> input_size(dim);
  input_size[0] = output_size[output_batch_size_dim];
  // weight dim 1 holds in-channels-per-group, so multiply groups back in
  input_size[1] = weight_size[weight_input_channels_dim] * groups;
  for (const auto d : c10::irange(2, dim)) {
    auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1;
    input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) +
                     kernel + output_padding[d - 2];
  }
  return input_size;
}
268
+
269
// Symbolic-int forwarder; `groups` is moved since SymInt is not trivially
// copyable.
inline std::vector<c10::SymInt> conv_input_size(
    SymIntArrayRef output_size, SymIntArrayRef weight_size,
    SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
) {
  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, std::move(groups));
}
275
+
276
// Concrete-int forwarder to the shared template implementation.
inline std::vector<int64_t> conv_input_size(
    IntArrayRef output_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
}
282
+
283
// Recovers the weight shape from input/output sizes.  Note stride/dilation
// stay plain IntArrayRef even in the symbolic instantiation.
template <typename T>
std::vector<T> _conv_weight_size(
    ArrayRef<T> input_size, ArrayRef<T> output_size,
    ArrayRef<T> padding, ArrayRef<T> output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  auto dim = input_size.size();
  std::vector<T> weight_size(dim);
  weight_size[0] = output_size[1];
  weight_size[1] = input_size[1] / groups;  // in-channels per group
  for (const auto d : c10::irange(2, dim)) {
    // undilated kernel extent, then divide dilation back out
    auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
                  + padding[d - 2] * 2 - output_padding[d - 2];
    weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
  }
  return weight_size;
}
299
+
300
// Symbolic-int forwarder to the shared template implementation.
inline std::vector<c10::SymInt> conv_weight_size(
    SymIntArrayRef input_size, SymIntArrayRef output_size,
    SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}
306
+
307
// Concrete-int forwarder to the shared template implementation.
inline std::vector<int64_t> conv_weight_size(
    IntArrayRef input_size, IntArrayRef output_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}
313
+
314
+ inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
315
+ std::vector<int64_t> shape(dim, 1);
316
+ shape[1] = -1;
317
+ return bias.reshape(shape);
318
+ }
319
+
320
// Suggests the memory format cuDNN convolutions should use for this
// input/weight pair: channels-last (2d or 3d) when the cuDNN version
// supports it and either operand already suggests that layout; contiguous
// otherwise.
inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
  // disable NHWC for float64 input.
  if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
      input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return at::MemoryFormat::Contiguous;
  }
  long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();
  auto weight_ndim = weight.ndimension();

  // 2d NHWC requires cuDNN >= 7.6.3 and a 4-D weight.
  bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && (
    (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
    (weight_memory_format == at::MemoryFormat::ChannelsLast)
  );
  if (can_use_cudnn_channels_last_2d) {
    return at::MemoryFormat::ChannelsLast;
  }

  // 3d NDHWC requires cuDNN >= 8.0.5 and a 5-D weight.
  bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && (
    (input_memory_format  == at::MemoryFormat::ChannelsLast3d) ||
    (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
  );
  if (can_use_cudnn_channels_last_3d) {
    return at::MemoryFormat::ChannelsLast3d;
  }

  return at::MemoryFormat::Contiguous;
}
350
+
351
// controls whether emptyCache will be called following cudnn conv benchmarking
// Setter/getter pair for the process-wide flag; defined in the .cpp.
TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
354
+
355
+
356
// Whether a MIOpen convolution should run channels-last (NHWC).  Requires
// the PYTORCH_MIOPEN_SUGGEST_NHWC opt-in env var and a channels-last
// suggestion from either operand; float64 always uses contiguous.
inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {

  // disable NHWC for float64 input.
  if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
      input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return false;
  }

  bool can_use_miopen_channels_last_2d = false;
  // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
  // See #64427
  // The env var is sampled once and cached for the process lifetime.
  static std::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && (
      ( (input_memory_format == at::MemoryFormat::ChannelsLast) ||
        (weight_memory_format == at::MemoryFormat::ChannelsLast) )
    );

  // 3d (NDHWC) is not supported yet; kept for symmetry with other backends.
  bool can_use_miopen_channels_last_3d = false;

  return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
}
382
+
383
+ inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
384
+
385
+ // disable NHWC for float64 input.
386
+ if (input.scalar_type() == at::kDouble ||
387
+ weight.scalar_type() == at::kDouble) {
388
+ return false;
389
+ }
390
+
391
+ // disable NHWC for MkldnnCPU tensor.
392
+ if (input.is_mkldnn() || weight.is_mkldnn()) {
393
+ return false;
394
+ }
395
+
396
+ auto input_memory_format = input.suggest_memory_format();
397
+ auto weight_memory_format = weight.suggest_memory_format();
398
+
399
+ bool can_use_mkldnn_channels_last_2d =
400
+ (input_memory_format == at::MemoryFormat::ChannelsLast) ||
401
+ (weight_memory_format == at::MemoryFormat::ChannelsLast);
402
+
403
+ bool can_use_mkldnn_channels_last_3d =
404
+ (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
405
+ (weight_memory_format == at::MemoryFormat::ChannelsLast3d);
406
+
407
+ return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
408
+ }
409
+
410
+ inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
411
+
412
+ auto input_memory_format = input.suggest_memory_format();
413
+ auto weight_memory_format = weight.suggest_memory_format();
414
+
415
+ bool can_use_thnn_channels_last_2d = input.device().is_cpu() && (
416
+ (input_memory_format == at::MemoryFormat::ChannelsLast) || (
417
+ weight_memory_format == at::MemoryFormat::ChannelsLast));
418
+
419
+ return can_use_thnn_channels_last_2d;
420
+ }
421
+
422
+ inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
423
+
424
+ // check layout only for xpu tensor.
425
+ if (!input.is_xpu() || !weight.is_xpu()) {
426
+ return false;
427
+ }
428
+
429
+ // disable NHWC for float64 input.
430
+ if (input.scalar_type() == at::kDouble ||
431
+ weight.scalar_type() == at::kDouble) {
432
+ return false;
433
+ }
434
+
435
+ auto input_memory_format = input.suggest_memory_format();
436
+ auto weight_memory_format = weight.suggest_memory_format();
437
+
438
+ bool can_use_xpu_channels_last_2d =
439
+ (input_memory_format == at::MemoryFormat::ChannelsLast) ||
440
+ (weight_memory_format == at::MemoryFormat::ChannelsLast);
441
+
442
+ bool can_use_xpu_channels_last_3d =
443
+ (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
444
+ (weight_memory_format == at::MemoryFormat::ChannelsLast3d);
445
+
446
+ return can_use_xpu_channels_last_2d || can_use_xpu_channels_last_3d;
447
+ }
448
+
449
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvolutionMM3d.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+
3
+ namespace at::native {
4
+
5
// Backward pass of the slow 3d convolution on CPU.  `output_mask` selects
// which gradients to compute — presumably (grad_input, grad_weight,
// grad_bias) in that order; confirm at the definition.
// NOTE(review): this header has no `#pragma once`/include guard — confirm
// whether repeated inclusion is a concern.
std::tuple<Tensor, Tensor, Tensor> slow_conv3d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    std::array<bool, 3> output_mask);
13
+
14
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Copy.h ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
namespace at {

// Forward declarations keep this header light; full definitions come from
// the ATen core headers.
class Tensor;
struct TensorIterator;
class TensorBase;

namespace native {

// Device copy kernel signature; `non_blocking` is forwarded from
// Tensor::copy_(…, non_blocking) — presumably requests an asynchronous
// copy where the backend supports it; confirm at the kernel definitions.
using copy_fn = void (*)(TensorIterator&, bool non_blocking);

DECLARE_DISPATCH(copy_fn, copy_stub);

// Copies `src` into `dst`; the name indicates it tolerates overlapping
// memory between the two — confirm at the definition.
TORCH_API void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src);

} // namespace native
} // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Cross.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
namespace at {
class Tensor;

namespace native {

// Cross-product kernel signature: (out, a, b, d) — presumably `d` is the
// dimension along which the product is taken; confirm at the kernel
// definitions.
using cross_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const int64_t d);

DECLARE_DISPATCH(cross_fn, cross_stub);

}} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <algorithm>
4
+ #include <vector>
5
+
6
+ #include <ATen/div_rtn.h>
7
+ #include <ATen/core/Tensor.h>
8
+ #include <c10/util/irange.h>
9
+
10
// Asserts that tensor T has rank DIM and that T.size(DIM_SIZE) == SIZE,
// producing an error message that includes T's actual shape otherwise.
#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
  TORCH_CHECK(                                       \
      T.dim() == DIM && T.size(DIM_SIZE) == SIZE,    \
      "Need " #T " of dimension ",                   \
      DIM,                                           \
      " and " #T ".size[",                           \
      DIM_SIZE,                                      \
      "] == ",                                       \
      SIZE,                                          \
      " but got input to be of shape ",              \
      T.sizes())
21
+
22
+ namespace at::native::internal {
23
+ namespace {
24
+ inline bool all_positive(IntArrayRef& arr) {
25
+ return std::all_of(
26
+ arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
27
+ }
28
+
29
+ inline bool all_nonnegative(std::vector<int64_t>& arr) {
30
+ return std::all_of(
31
+ arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
32
+ }
33
+
34
+ } // namespace
35
+
36
+ // calculate the rear part of output tensor sizes
37
+ template <int64_t dim>
38
+ std::vector<int64_t> get_output_size(
39
+ const Tensor& input,
40
+ IntArrayRef kernel_size,
41
+ IntArrayRef stride_size,
42
+ IntArrayRef pad_size,
43
+ IntArrayRef dilation_size) {
44
+ std::vector<int64_t> sizes;
45
+ for (const auto index : c10::irange(dim)) {
46
+ sizes.push_back(
47
+ div_rtn<int64_t>(
48
+ input.size(index + input.dim() - dim) + 2 * pad_size[index] -
49
+ (dilation_size[index] * (kernel_size[index] - 1) + 1),
50
+ stride_size[index]) +
51
+ 1);
52
+ }
53
+ return sizes;
54
+ }
55
+
56
// calculate the sizes of output tensor
// Full output shape: spatial sizes from the overload above, with the
// out-channel count (weight.size(0)) prepended, and the batch size
// prepended as well when the input is batched (rank dim + 2).
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
  output_size.insert(output_size.begin(), weight.size(0));
  if (input.dim() == dim + 2) {
    output_size.insert(output_size.begin(), input.size(0));
  }
  return output_size;
}
73
/*
  slow_conv_dilated_shape_check - check user-input to dilated convolution
  forward and backward functions.
*/
template <int64_t dim>
void slow_conv_dilated_shape_check(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  /*
    When the following tensors are defined:

    bias, grad_weight, grad_output

    then these are assumed to be contiguous without checking
    because of these tensors are made contiguous by calling
    .contiguous() method or by resizing of zero-sized tensors in
    forward/backward functions.

    When grad_weight is defined then it is assumed without
    checking to have the same shape as weight, see backward
    functions.
   */
  // Check size arguments
  TORCH_CHECK(
      kernel_size.size() == dim,
      "kernel sizes length should be ",
      dim,
      ", but got ",
      kernel_size.size());
  TORCH_CHECK(
      stride_size.size() == dim,
      "strides length should be ",
      dim,
      ", but got ",
      stride_size.size());
  TORCH_CHECK(
      dilation_size.size() == dim,
      "dilations length should be ",
      dim,
      ", but got ",
      dilation_size.size());
  TORCH_CHECK(
      pad_size.size() == dim,
      "pads length should be ",
      dim,
      ", but got ",
      pad_size.size());

  // NOTE(review): pad_size values are not checked for positivity here; a
  // negative pad is only caught indirectly via the output-size check below.
  TORCH_CHECK(
      all_positive(kernel_size),
      "kernel size should be greater than zero, but got ",
      kernel_size);
  TORCH_CHECK(
      all_positive(stride_size),
      "stride should be greater than zero, but got ",
      stride_size);
  TORCH_CHECK(
      all_positive(dilation_size),
      "dilation should be greater than zero, but got ",
      dilation_size);

  // check input
  TORCH_CHECK(input.defined(), "input must be defined");
  // batched input has rank dim + 2 (batch + channels + dim spatial axes)
  bool is_batch = input.dim() == dim + 2;
  int64_t n = (is_batch ? 2 : 1);  // number of leading non-spatial dims
  int64_t ndim = n + dim;
  if (!is_batch) {
    // input dim has to be dim + 1 if not batched
    // NOTE(review): the "4D or 5D" wording is hard-coded and only accurate
    // for dim == 3; for dim == 2 it should read "3D or 4D" — confirm.
    TORCH_CHECK(
        input.dim() == dim + 1,
        "input must be 4D or 5D tensor but got ",
        input.dim(),
        "D tensor");
  }

  // check output sizes
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);

  TORCH_CHECK(
      all_nonnegative(output_size),
      "calculated output size ",
      output_size,
      " is too small (all sizes must be non-negative)");

  // check weight
  TORCH_CHECK(weight.defined(), "weight must be defined");
  TORCH_CHECK(
      weight.dim() == dim + 2,
      "weight must be ",
      dim + 2,
      "D tensor but got ",
      weight.dim(),
      "D tensor dim=",
      dim);
  TORCH_CHECK(
      weight.sizes().slice(2) == kernel_size,
      "weight[2:] shape ",
      weight.sizes().slice(2),
      " must be equal to kernel_size ",
      kernel_size);

  // input channel count must match weight's in-channel dimension
  TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));

  // check bias when present
  if (bias.defined()) {
    TORCH_CHECK(
        bias.dim() == 1,
        "bias must be 1D tensor but got ",
        bias.dim(),
        "D tensor");
    TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
  }

  // check grad_output when present
  if (grad_output.defined()) {
    TORCH_CHECK(
        grad_output.dim() == ndim,
        "grad_output must be ",
        ndim,
        "D tensor but got ",
        grad_output.dim(),
        "D tensor");
    if (is_batch) {
      TORCH_CHECK(
          grad_output.size(0) == input.size(0),
          "grad_output.size(0)=",
          grad_output.size(0),
          " must be input.size(0)=",
          input.size(0));
    }
    // channel dim of grad_output must match the out-channel count
    TORCH_CHECK(
        grad_output.size(n - 1) == weight.size(0),
        "grad_output.size(",
        n - 1,
        ")=",
        grad_output.size(n - 1),
        " must be weight.size(0)=",
        weight.size(0));
    // spatial dims of grad_output must match the computed output size
    TORCH_CHECK(
        grad_output.sizes().slice(n) == output_size,
        "grad_output[",
        n,
        ":] shape",
        grad_output.sizes().slice(n),
        " must be equal to output size ",
        output_size);
  }
}
227
+ }
228
+
229
+ } // namespace at::native::internal
.venv/lib/python3.11/site-packages/torch/include/ATen/native/DispatchStub.h ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/DeviceType.h>
4
+ #include <c10/macros/Macros.h>
5
+ #include <c10/util/Array.h>
6
+
7
+ #include <atomic>
8
+ #include <utility>
9
+ #include <variant>
10
+
11
+ // Implements instruction set specific function dispatch.
12
+ //
13
+ // Kernels that may make use of specialized instruction sets (e.g. AVX2) are
14
+ // compiled multiple times with different compiler flags (e.g. -mavx2). A
15
+ // DispatchStub contains a table of function pointers for a kernel. At runtime,
16
+ // the fastest available kernel is chosen based on the features reported by
17
+ // cpuinfo.
18
+ //
19
+ // Example:
20
+ //
21
+ // In native/MyKernel.h:
22
+ // using fn_type = void(*)(const Tensor& x);
23
+ // DECLARE_DISPATCH(fn_type, stub);
24
+ //
25
+ // In native/MyKernel.cpp
26
+ // DEFINE_DISPATCH(stub);
27
+ //
28
+ // In native/cpu/MyKernel.cpp:
29
+ // namespace {
30
+ // // use anonymous namespace so that different cpu versions won't conflict
31
+ // void kernel(const Tensor& x) { ... }
32
+ // }
33
+ // REGISTER_DISPATCH(stub, &kernel);
34
+ //
35
+ // To call:
36
+ // stub(kCPU, tensor);
37
+ //
38
+ // TODO: CPU instruction set selection should be folded into whatever
39
+ // the main dispatch mechanism is.
40
+ //
41
+ // Supported device types for registration:
42
+ // - CPU: Central Processing Unit
43
+ // - CUDA: NVIDIA GPUs
44
+ // - HIP: AMD GPUs
45
+ // - MPS: Apple Silicon GPUs (Metal Performance Shaders)
46
+ // - MTIA: Meta Training and Inference Devices
47
+ // - XPU: Intel GPUs
48
+ // - PrivateUse1: Reserved for private/custom device types
49
+ //
50
+ // If you want to update the list of supported devices, add a new dispatch_ptr
51
+ // member in DispatchStubImpl.h and update the get_call_ptr switch.
52
+ // As well you will need to update the inlined list in 'is_device_supported`
53
+ //
54
+ //
55
+ // ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
56
+ C10_CLANG_DIAGNOSTIC_PUSH()
57
+ C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")
58
+
59
+ namespace at::native {
60
+
61
+ enum class CPUCapability {
62
+ DEFAULT = 0,
63
+ #if defined(HAVE_VSX_CPU_DEFINITION)
64
+ VSX = 1,
65
+ #elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
66
+ ZVECTOR = 1,
67
+ #else
68
+ AVX2 = 1,
69
+ AVX512 = 2,
70
+ #endif
71
+ NUM_OPTIONS
72
+ };
73
+
74
+ // Enum for error types
75
+ enum class ErrorType {
76
+ MissingDeviceKernel,
77
+ DeviceNotSupported
78
+ };
79
+
80
+ // Alias for the return type using std::variant
81
+ using DispatchResult = std::variant<void*, ErrorType>;
82
+
83
+ CPUCapability get_cpu_capability();
84
+
85
+ template <typename FnPtr, typename T>
86
+ struct DispatchStub;
87
+
88
+ /**
89
+ * The sole purpose of this class is to outline methods that don't need to be
90
+ * specialized or otherwise inlined and duplicated (by the compiler due to
91
+ * template expansion), since it causes size bloat if there are a significant
92
+ * number of specialization of the DispatchStub<> class.
93
+ */
94
+ struct TORCH_API DispatchStubImpl {
95
+
96
+ // The DispatchStubImpl::try_get_call_ptr() method is used to get the call
97
+ // pointer for a given device type. If the call pointer is not found,
98
+ // DispatchStubImpl::try_get_call_ptr() returns an ErrorType.
99
+ // The main difference between try_get_call_ptr() and get_call_ptr() is that
100
+ // try_get_call_ptr() will return the ErrorType and not raise an exception.
101
+ DispatchResult try_get_call_ptr(
102
+ c10::DeviceType device_type
103
+ , void *DEFAULT
104
+ #ifdef HAVE_AVX512_CPU_DEFINITION
105
+ , void *AVX512
106
+ #endif
107
+ #ifdef HAVE_AVX2_CPU_DEFINITION
108
+ , void *AVX2
109
+ #endif
110
+ #ifdef HAVE_VSX_CPU_DEFINITION
111
+ , void *VSX
112
+ #endif
113
+ #ifdef HAVE_ZVECTOR_CPU_DEFINITION
114
+ , void *ZVECTOR
115
+ #endif
116
+ );
117
+
118
+ // Analogous to try_get_call_ptr(), but it will return the ErrorType and not
119
+ // raise an exception.
120
+ DispatchResult try_choose_cpu_impl(
121
+ void *DEFAULT
122
+ #ifdef HAVE_AVX512_CPU_DEFINITION
123
+ , void *AVX512
124
+ #endif
125
+ #ifdef HAVE_AVX2_CPU_DEFINITION
126
+ , void *AVX2
127
+ #endif
128
+ #ifdef HAVE_VSX_CPU_DEFINITION
129
+ , void *VSX
130
+ #endif
131
+ #ifdef HAVE_ZVECTOR_CPU_DEFINITION
132
+ , void *ZVECTOR
133
+ #endif
134
+ );
135
+
136
+
137
+ void* get_call_ptr(
138
+ c10::DeviceType device_type
139
+ , void *DEFAULT
140
+ #ifdef HAVE_AVX512_CPU_DEFINITION
141
+ , void *AVX512
142
+ #endif
143
+ #ifdef HAVE_AVX2_CPU_DEFINITION
144
+ , void *AVX2
145
+ #endif
146
+ #ifdef HAVE_VSX_CPU_DEFINITION
147
+ , void *VSX
148
+ #endif
149
+ #ifdef HAVE_ZVECTOR_CPU_DEFINITION
150
+ , void *ZVECTOR
151
+ #endif
152
+ );
153
+
154
+ /**
155
+ * The CPU Dispatch actual method is chosen in decreasing order of preference by
156
+ * DispatchStubImpl::choose_cpu_impl() in case none is found by
157
+ * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
158
+ */
159
+ void* choose_cpu_impl(
160
+ void *DEFAULT
161
+ #ifdef HAVE_AVX512_CPU_DEFINITION
162
+ , void *AVX512
163
+ #endif
164
+ #ifdef HAVE_AVX2_CPU_DEFINITION
165
+ , void *AVX2
166
+ #endif
167
+ #ifdef HAVE_VSX_CPU_DEFINITION
168
+ , void *VSX
169
+ #endif
170
+ #ifdef HAVE_ZVECTOR_CPU_DEFINITION
171
+ , void *ZVECTOR
172
+ #endif
173
+ );
174
+
175
+ // Fixing dispatch error in Windows debug builds.
176
+ // See https://github.com/pytorch/pytorch/issues/22681 for more details.
177
+ #if defined(_MSC_VER) && defined(_DEBUG)
178
+ std::atomic<void*> cpu_dispatch_ptr;
179
+ void* cuda_dispatch_ptr;
180
+ void* hip_dispatch_ptr;
181
+ void* mps_dispatch_ptr;
182
+ void* mtia_dispatch_ptr;
183
+ #if defined(USE_XPU)
184
+ void* xpu_dispatch_ptr;
185
+ #endif
186
+ void* privateuse1_dispatch_ptr;
187
+ #else
188
+ std::atomic<void*> cpu_dispatch_ptr{nullptr};
189
+ void* cuda_dispatch_ptr = nullptr;
190
+ void* hip_dispatch_ptr = nullptr;
191
+ void* mps_dispatch_ptr = nullptr;
192
+ void* mtia_dispatch_ptr = nullptr;
193
+ #if defined(USE_XPU)
194
+ void* xpu_dispatch_ptr = nullptr;
195
+ #endif
196
+ void* privateuse1_dispatch_ptr = nullptr;
197
+ #endif
198
+ };
199
+
200
+ template <typename rT, typename T, typename... Args>
201
+ struct DispatchStub<rT (*)(Args...), T> {
202
+ using FnPtr = rT (*) (Args...);
203
+
204
+ DispatchStub() = default;
205
+ DispatchStub(const DispatchStub&) = delete;
206
+ DispatchStub& operator=(const DispatchStub&) = delete;
207
+
208
+ private:
209
+ FnPtr get_call_ptr(const c10::DeviceType device_type) {
210
+ return reinterpret_cast<FnPtr>(
211
+ impl.get_call_ptr(device_type
212
+ , reinterpret_cast<void*>(DEFAULT)
213
+ #ifdef HAVE_AVX512_CPU_DEFINITION
214
+ , reinterpret_cast<void*>(AVX512)
215
+ #endif
216
+ #ifdef HAVE_AVX2_CPU_DEFINITION
217
+ , reinterpret_cast<void*>(AVX2)
218
+ #endif
219
+ #ifdef HAVE_VSX_CPU_DEFINITION
220
+ , reinterpret_cast<void*>(VSX)
221
+ #endif
222
+ #ifdef HAVE_ZVECTOR_CPU_DEFINITION
223
+ , reinterpret_cast<void*>(ZVECTOR)
224
+ #endif
225
+ )
226
+ );
227
+ }
228
+
229
+ public:
230
+ template <typename... ArgTypes>
231
+ rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
232
+ FnPtr call_ptr = get_call_ptr(device_type);
233
+ return (*call_ptr)(std::forward<ArgTypes>(args)...);
234
+ }
235
+
236
+ void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
237
+ impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
238
+ }
239
+
240
+ #if defined(USE_XPU)
241
+ void set_xpu_dispatch_ptr(FnPtr fn_ptr){
242
+ impl.xpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
243
+ }
244
+ #endif
245
+
246
+ void set_hip_dispatch_ptr(FnPtr fn_ptr) {
247
+ impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
248
+ }
249
+
250
+ void set_mps_dispatch_ptr(FnPtr fn_ptr) {
251
+ impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
252
+ }
253
+
254
+ void set_mtia_dispatch_ptr(FnPtr fn_ptr) {
255
+ impl.mtia_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
256
+ }
257
+
258
+ void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
259
+ impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
260
+ }
261
+
262
+ // Returns true if the dispatcher has a kernel registered for this device
263
+ // type.
264
+ bool is_device_supported(const c10::DeviceType device_type) {
265
+ auto result = impl.try_get_call_ptr(device_type
266
+ , reinterpret_cast<void*>(DEFAULT)
267
+ #ifdef HAVE_AVX512_CPU_DEFINITION
268
+ , reinterpret_cast<void*>(AVX512)
269
+ #endif
270
+ #ifdef HAVE_AVX2_CPU_DEFINITION
271
+ , reinterpret_cast<void*>(AVX2)
272
+ #endif
273
+ #ifdef HAVE_VSX_CPU_DEFINITION
274
+ , reinterpret_cast<void*>(VSX)
275
+ #endif
276
+ #ifdef HAVE_ZVECTOR_CPU_DEFINITION
277
+ , reinterpret_cast<void*>(ZVECTOR)
278
+ #endif
279
+ );
280
+ if (std::holds_alternative<ErrorType>(result)){
281
+ return false;
282
+ }
283
+ return true;
284
+ };
285
+
286
+ static TORCH_API FnPtr DEFAULT;
287
+ #ifdef HAVE_AVX512_CPU_DEFINITION
288
+ static TORCH_API FnPtr AVX512;
289
+ #endif
290
+ #ifdef HAVE_AVX2_CPU_DEFINITION
291
+ static TORCH_API FnPtr AVX2;
292
+ #endif
293
+ #ifdef HAVE_VSX_CPU_DEFINITION
294
+ static TORCH_API FnPtr VSX;
295
+ #endif
296
+ #ifdef HAVE_ZVECTOR_CPU_DEFINITION
297
+ static TORCH_API FnPtr ZVECTOR;
298
+ #endif
299
+ private:
300
+ DispatchStubImpl impl;
301
+ };
302
+
303
+ namespace {
304
+ template <typename DispatchStub>
305
+ struct RegisterCUDADispatch {
306
+ RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
307
+ stub.set_cuda_dispatch_ptr(value);
308
+ }
309
+ };
310
+
311
+ template <typename DispatchStub>
312
+ struct RegisterXPUDispatch {
313
+ RegisterXPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
314
+ stub.set_xpu_dispatch_ptr(value);
315
+ }
316
+ };
317
+
318
+ template <typename DispatchStub>
319
+ struct RegisterMPSDispatch {
320
+ RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
321
+ stub.set_mps_dispatch_ptr(value);
322
+ }
323
+ };
324
+
325
+ template <typename DispatchStub>
326
+ struct RegisterHIPDispatch {
327
+ RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
328
+ // TODO: make this point at hip_dispatch_ptr
329
+ stub.set_cuda_dispatch_ptr(value);
330
+ }
331
+ };
332
+
333
+ template <typename DispatchStub>
334
+ struct RegisterMTIADispatch {
335
+ RegisterMTIADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
336
+ stub.set_mtia_dispatch_ptr(value);
337
+ }
338
+ };
339
+
340
+ template <typename DispatchStub>
341
+ struct RegisterPRIVATEUSE1Dispatch {
342
+ RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
343
+ stub.set_privateuse1_dispatch_ptr(value);
344
+ }
345
+ };
346
+
347
+ } // anonymous namespace
348
+ // Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
349
+ // the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
350
+ // adding parentheses and using helper struct to get rid of the parentheses, do
351
+ // not work with MSVC. So do a `using`-declaration if you need to pass in such
352
+ // `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
353
+ #define DECLARE_DISPATCH(fn, name) \
354
+ struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> { \
355
+ name##_DECLARE_DISPATCH_type() = default; \
356
+ name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete; \
357
+ name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
358
+ }; \
359
+ extern TORCH_API struct name##_DECLARE_DISPATCH_type name;
360
+
361
+ #define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name
362
+
363
+ #define REGISTER_ARCH_DISPATCH(name, arch, fn) \
364
+ template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub<name##_DECLARE_DISPATCH_type::FnPtr, struct name##_DECLARE_DISPATCH_type>::arch = fn;
365
+
366
+ #ifdef HAVE_AVX512_CPU_DEFINITION
367
+ #define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
368
+ #else
369
+ #define REGISTER_AVX512_DISPATCH(name, fn)
370
+ #endif
371
+
372
+ #ifdef HAVE_AVX2_CPU_DEFINITION
373
+ #define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
374
+ #else
375
+ #define REGISTER_AVX2_DISPATCH(name, fn)
376
+ #endif
377
+
378
+ #ifdef HAVE_VSX_CPU_DEFINITION
379
+ #define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
380
+ #else
381
+ #define REGISTER_VSX_DISPATCH(name, fn)
382
+ #endif
383
+
384
+ #ifdef HAVE_ZVECTOR_CPU_DEFINITION
385
+ #define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
386
+ #else
387
+ #define REGISTER_ZVECTOR_DISPATCH(name, fn)
388
+ #endif
389
+
390
+ // Macro to register the same kernel for all CPU arch types. This is useful
391
+ // if a kernel does not benefit from being recompiled across different arch types.
392
+ #define REGISTER_ALL_CPU_DISPATCH(name, fn) \
393
+ REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
394
+ REGISTER_AVX512_DISPATCH(name, fn) \
395
+ REGISTER_AVX2_DISPATCH(name, fn) \
396
+ REGISTER_VSX_DISPATCH(name, fn) \
397
+ REGISTER_ZVECTOR_DISPATCH(name, fn)
398
+
399
+ #define REGISTER_NO_CPU_DISPATCH(name) \
400
+ REGISTER_ALL_CPU_DISPATCH(name, nullptr)
401
+
402
+ #define REGISTER_CUDA_DISPATCH(name, fn) \
403
+ static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
404
+
405
+ #define REGISTER_XPU_DISPATCH(name, fn) \
406
+ static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
407
+
408
+ #define REGISTER_HIP_DISPATCH(name, fn) \
409
+ static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
410
+
411
+ #define REGISTER_MPS_DISPATCH(name, fn) \
412
+ static RegisterMPSDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
413
+
414
+ #define REGISTER_MTIA_DISPATCH(name, fn) \
415
+ static RegisterMTIADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
416
+
417
+ #define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
418
+ static RegisterPRIVATEUSE1Dispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
419
+
420
+ // NB: This macro must be used in an actual 'cu' file; if you try using
421
+ // it from a 'cpp' file it will not work!
422
+ #if defined(__CUDACC__)
423
+ #define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
424
+ #elif defined(__HIPCC__)
425
+ // TODO: cut this over to HIP dispatch once we stop pretending that CUDA
426
+ // is HIP in the PyTorch HIPify build.
427
+ #define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
428
+ // #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
429
+ #elif defined(__OBJC__) && defined(USE_MPS)
430
+ // NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
431
+ #define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
432
+ #elif defined(CPU_CAPABILITY)
433
+ // REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
434
+ // ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
435
+ #ifdef CPU_CAPABILITY_AVX512
436
+ #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
437
+ #else
438
+ #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
439
+ #endif
440
+ #define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
441
+ #endif
442
+ } // namespace at::native
443
+
444
+ C10_CLANG_DIAGNOSTIC_POP()
.venv/lib/python3.11/site-packages/torch/include/ATen/native/DistributionTemplates.h ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/Dispatch_v2.h>
6
+ #include <ATen/Generator.h>
7
+ #include <ATen/ExpandUtils.h>
8
+ #include <ATen/Tensor.h>
9
+ #include <ATen/MemoryOverlap.h>
10
+ #include <ATen/NamedTensorUtils.h>
11
+ #include <ATen/native/Resize.h>
12
+ #include <ATen/native/TensorIterator.h>
13
+ #include <cmath>
14
+ #include <limits>
15
+ #include <optional>
16
+
17
+ #ifndef AT_PER_OPERATOR_HEADERS
18
+ #include <ATen/Functions.h>
19
+ #else
20
+ #include <ATen/ops/empty_like.h>
21
+ #include <ATen/ops/empty.h>
22
+ #include <ATen/ops/full.h>
23
+ #include <ATen/ops/view_as_real.h>
24
+ #endif
25
+
26
+ namespace at::native::templates {
27
+
28
+ // ==================================================== Random ========================================================
29
+
30
+ // The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can be used as actual `from`.
31
+ // The current implementation of `random_` uses uint64_t arithmetics and casts the result to the target dtype(scalar_t).
32
+ // This casting can result in generating numbers that happen to be greater or equal to `to` value. For instance:
33
+ //
34
+ // auto actual = torch::empty({3, 3}, torch::half);
35
+ // actual.random_(0, 65504);
36
+ //
37
+ // If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it becomes 65504
38
+ // and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to`
39
+ // moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to
40
+ // the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous
41
+ // available number for torch::half dtype.
42
+ template<typename scalar_t>
43
+ int64_t update_from(int64_t from) {
44
+ static_assert(
45
+ std::is_floating_point<scalar_t>::value ||
46
+ std::is_same<scalar_t, at::Half>::value ||
47
+ std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
48
+ const auto from_plus_1 = static_cast<int64_t>(static_cast<scalar_t>(from + 1));
49
+ if (from_plus_1 < from) {
50
+ int64_t from_ = std::abs(from + 1);
51
+ int n = 0;
52
+ while (from_ >>= 1) ++n;
53
+ // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
54
+ from = from_plus_1 + (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
55
+ }
56
+ return from;
57
+ }
58
+
59
+ template<typename scalar_t>
60
+ int64_t update_to(int64_t to) {
61
+ static_assert(
62
+ std::is_floating_point<scalar_t>::value ||
63
+ std::is_same<scalar_t, at::Half>::value ||
64
+ std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
65
+ const auto to_minus_1 = static_cast<int64_t>(static_cast<scalar_t>(to - 1));
66
+ if (to_minus_1 >= to) {
67
+ int64_t to_ = std::abs(to - 1);
68
+ int n = 0;
69
+ while (to_ >>= 1) ++n;
70
+ // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
71
+ to = to_minus_1 - (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
72
+ }
73
+ return to;
74
+ }
75
+
76
+ // Return earlier for not invoking kernel.
77
+ // See https://github.com/pytorch/pytorch/issues/103418 for more details
78
+ #define CHECK_EMPTY_AND_RETURN(tensor) \
79
+ if (tensor.numel() == 0) { \
80
+ return tensor; \
81
+ }
82
+
83
+ template<template<typename> class random_kernel, typename RNG>
84
+ at::Tensor& random_impl(at::Tensor& self, std::optional<Generator> generator) {
85
+ CHECK_EMPTY_AND_RETURN(self);
86
+ auto iter = at::TensorIterator::borrowing_nullary_op(self);
87
+ random_kernel<RNG>()(iter, generator);
88
+ return self;
89
+ }
90
+
91
+ #define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \
92
+ TORCH_CHECK(var >= min && var <= max, name , " is out of bounds for ", dtype); \
93
+
94
+ #define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \
95
+ if (var < -(1LL << digits) || var > (1LL << digits)) { \
96
+ TORCH_WARN(name , " is out of bounds [-(2^", digits, "), 2^", digits, "]. ", \
97
+ "Due to precision limitations ", dtype, " can support discrete uniform distribution only within this range. ", \
98
+ "This warning will become an error in version 1.7 release, please fix the code in advance"); \
99
+ }
100
+
101
+ inline void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMeta dtype) {
102
+ const auto scalar_type = typeMetaToScalarType(dtype);
103
+ if (isFloatingType(scalar_type)) {
104
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "check_random_fp_bounds", [&] {
105
+ const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
106
+ const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
107
+ CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
108
+ CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
109
+
110
+ constexpr auto digits = std::numeric_limits<scalar_t>::digits;
111
+ WARN_OUT_OF_BOUNDS(from, "from", digits, dtype);
112
+ WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype);
113
+ });
114
+ } else if (scalar_type == kUInt64) {
115
+ // When you do a comparison between int64_t and uint64_t, the usual
116
+ // arithmetic conversions say that the int64_t value is promoted to
117
+ // unsigned. But this conversion wraps around: if I had -1 as my int64_t,
118
+ // then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never
119
+ // the right thing to do.
120
+ CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype);
121
+ CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype);
122
+ } else if (isIntegralType(scalar_type, /*includeBool=*/true)) {
123
+ AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() {
124
+ const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());
125
+ const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
126
+ CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
127
+ CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
128
+ }), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool);
129
+ } else {
130
+ TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types");
131
+ }
132
+ }
133
+
134
+ template<template<typename> class random_from_to_kernel, typename RNG>
135
+ at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, std::optional<int64_t> to_opt, std::optional<Generator> generator) {
136
+ uint64_t range = 0;
137
+ auto iter = at::TensorIterator::borrowing_nullary_op(self);
138
+ if (to_opt.has_value()) {
139
+ // [from, to)
140
+ int64_t to = *to_opt;
141
+ TORCH_CHECK(from < to, "random_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to);
142
+ if (isFloatingType(iter.dtype())) {
143
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_update_from_to", [&] {
144
+ from = update_from<scalar_t>(from);
145
+ to = update_to<scalar_t>(to);
146
+ TORCH_CHECK(from < to, "random_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to);
147
+ });
148
+ }
149
+ check_from_to_in_range(from, to - 1, self.dtype());
150
+ CHECK_EMPTY_AND_RETURN(self);
151
+ range = static_cast<uint64_t>(to) - static_cast<uint64_t>(from);
152
+ random_from_to_kernel<RNG>()(iter, range, from, generator);
153
+ } else if (from != std::numeric_limits<int64_t>::lowest()) {
154
+ // [from, std::numeric_limits<int64_t>::max()]
155
+ int64_t to_inc = 0;
156
+ if (isFloatingType(iter.dtype())) {
157
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_from_to_range_calc", [&] {
158
+ constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
159
+ to_inc = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max() : static_cast<int64_t>(scalar_t_max);
160
+ from = update_from<scalar_t>(from);
161
+ TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc);
162
+ });
163
+ } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
164
+ AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] {
165
+ if constexpr (std::is_same_v<scalar_t, bool>) {
166
+ to_inc = static_cast<int64_t>(true);
167
+ } else {
168
+ to_inc = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
169
+ }
170
+ }), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool);
171
+ } else {
172
+ TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types");
173
+ }
174
+ check_from_to_in_range(from, to_inc, self.dtype());
175
+ CHECK_EMPTY_AND_RETURN(self);
176
+ range = static_cast<uint64_t>(to_inc) - static_cast<uint64_t>(from) + 1;
177
+ random_from_to_kernel<RNG>()(iter, range, from, generator);
178
+ } else {
179
+ // [std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max()]
180
+ // range = 2^64
181
+ CHECK_EMPTY_AND_RETURN(self);
182
+ random_from_to_kernel<RNG>()(iter, generator);
183
+ }
184
+ return self;
185
+ }
186
+
187
+ // ==================================================== Normal ========================================================
188
+
189
+ #define CHECK_NORMAL_TENSOR_STD(std) \
190
+ do { \
191
+ TORCH_CHECK( \
192
+ !std.is_complex(), \
193
+ "normal expects standard deviation to be non-complex"); \
194
+ TORCH_CHECK( \
195
+ std.numel() == 0 || std.is_meta() || std.min().ge(0).item<bool>(), \
196
+ "normal expects all elements of std >= 0.0"); \
197
+ } while (0)
198
+
199
+ #define CHECK_NORMAL_STD(std) \
200
+ TORCH_CHECK(std >= 0.0, "normal expects std >= 0.0, but found std ", std);
201
+
202
+ template<template<typename> class normal_kernel, typename RNG>
203
+ Tensor& normal_impl_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
204
+ CHECK_NORMAL_STD(std);
205
+ CHECK_EMPTY_AND_RETURN(self);
206
+
207
+ if (self.is_complex()) {
208
+ auto float_tensor = at::view_as_real(self);
209
+ // variance for normal distribution of the real and imaginary values
210
+ // is half of the input variance
211
+ normal_kernel<RNG>()(float_tensor, mean, std/(std::sqrt(2)), gen);
212
+ } else {
213
+ normal_kernel<RNG>()(self, mean, std, gen);
214
+ }
215
+ return self;
216
+ }
217
+
218
+ template<template<typename> class normal_kernel, typename RNG>
219
+ Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, std::optional<Generator> gen) {
220
+ CHECK_NORMAL_STD(std);
221
+ auto std_tensor = at::empty_like(output, MemoryFormat::Contiguous);
222
+ auto shape = at::infer_size(mean.sizes(), std_tensor.sizes());
223
+ at::native::resize_output(output, shape);
224
+ normal_impl_<normal_kernel, RNG>(output, 0, std, gen);
225
+ output.add_(mean);
226
+ return output;
227
+ }
228
+
229
+ template<template<typename> class normal_kernel, typename RNG>
230
+ Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, std::optional<Generator> gen) {
231
+ CHECK_NORMAL_TENSOR_STD(std);
232
+ auto mean_tensor = at::full({}, mean, output.options());
233
+ auto shape = at::infer_size(mean_tensor.sizes(), std.sizes());
234
+ at::native::resize_output(output, shape);
235
+ normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
236
+ // CUDA NB: addcmul_out copies the tensor to be added into the output.
237
+ // The previous function here was addcmul_out(output, mean_tensor, output, std, 1);
238
+ // The third argument is not a constant reference and hence the samples in output are overwritten.
239
+ // Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std
240
+ output.mul_(std).add_(mean_tensor);
241
+ return output;
242
+ }
243
+
244
+ template<template<typename> class normal_kernel, typename RNG>
245
+ Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
246
+ CHECK_NORMAL_TENSOR_STD(std);
247
+ auto shape = at::infer_size(mean.sizes(), std.sizes());
248
+ at::native::resize_output(output, shape);
249
+ normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
250
+ // CUDA NB: addcmul_out copies the tensor to be added into the output.
251
+ // The previous function here was addcmul_out(output, mean, output, std, 1);
252
+ // The third argument is not a constant reference and hence the samples in output are overwritten.
253
+ // Consequently, the computation performed is mean + mean * std instead of mean + output * std
254
+ output.mul_(std).add_(mean);
255
+ return output;
256
+ }
257
+
258
+ template<template<typename> class normal_kernel, typename RNG>
259
+ Tensor normal_impl(const Tensor& mean, double std, std::optional<Generator> gen) {
260
+ CHECK_NORMAL_STD(std);
261
+ Tensor ret = at::empty_like(mean, MemoryFormat::Contiguous);
262
+ normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
263
+ return ret;
264
+ }
265
+
266
+ template<template<typename> class normal_kernel, typename RNG>
267
+ Tensor normal_impl(double mean, const Tensor& std, std::optional<Generator> gen) {
268
+ CHECK_NORMAL_TENSOR_STD(std);
269
+ Tensor ret = at::empty_like(std, MemoryFormat::Contiguous);
270
+ normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
271
+ return ret;
272
+ }
273
+
274
+ template<template<typename> class normal_kernel, typename RNG>
275
+ Tensor normal_impl(const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
276
+ CHECK_NORMAL_TENSOR_STD(std);
277
+ auto shape = at::infer_size(mean.sizes(), std.sizes());
278
+ Tensor ret = at::empty(shape, mean.options(), MemoryFormat::Contiguous);
279
+ normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
280
+ return ret;
281
+ }
282
+
283
+ // ==================================================== Uniform =======================================================
284
+
285
+ template<template<typename> class uniform_kernel, typename RNG>
286
+ at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, std::optional<Generator> generator) {
287
+ if (self.is_complex()) {
288
+ CHECK_EMPTY_AND_RETURN(self);
289
+ auto float_tensor = at::view_as_real(self);
290
+ uniform_impl_<uniform_kernel, RNG>(float_tensor, from, to, generator);
291
+ } else {
292
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "check_uniform_bounds", [&] {
293
+ [[maybe_unused]] const auto dtype = self.dtype();
294
+ const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
295
+ const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
296
+ CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
297
+ CHECK_OUT_OF_BOUNDS(to, "to", min, max, dtype);
298
+ TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to);
299
+ TORCH_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
300
+ "uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()),
301
+ ">::max(), but found to=", to, " and from=", from,
302
+ " which result in to-from to exceed the limit");
303
+ from = std::min(std::max(from, min), max);
304
+ to = std::max(std::min(to, max), min);
305
+ });
306
+ CHECK_EMPTY_AND_RETURN(self);
307
+ auto iter = at::TensorIterator::borrowing_nullary_op(self);
308
+ uniform_kernel<RNG>()(iter, from, to, generator);
309
+ }
310
+ return self;
311
+ }
312
+
313
+ // ================================================== LogNormal =======================================================
314
+
315
+ template<template<typename> class log_normal_kernel, typename RNG>
316
+ at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, std::optional<Generator> gen) {
317
+ TORCH_CHECK(std > 0.0, "log_normal_ expects std > 0.0, but found std=", std);
318
+ CHECK_EMPTY_AND_RETURN(self);
319
+ auto iter = TensorIterator::borrowing_nullary_op(self);
320
+ log_normal_kernel<RNG>()(iter, mean, std, gen);
321
+ return self;
322
+ }
323
+
324
+ // =================================================== Geometric ======================================================
325
+
326
+ template<template<typename> class geometric_kernel, typename RNG>
327
+ Tensor& geometric_impl_(Tensor& self, double p, std::optional<Generator> gen) {
328
+ TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p);
329
+ CHECK_EMPTY_AND_RETURN(self);
330
+ auto iter = TensorIterator::borrowing_nullary_op(self);
331
+ geometric_kernel<RNG>()(iter, p, gen);
332
+ return self;
333
+ }
334
+
335
+ // ================================================== Exponential =====================================================
336
+
337
+ template<template<typename> class exponential_kernel, typename RNG>
338
+ Tensor& exponential_impl_(Tensor& self, double lambda, std::optional<Generator> gen) {
339
+ TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda);
340
+ CHECK_EMPTY_AND_RETURN(self);
341
+ auto iter = TensorIterator::borrowing_nullary_op(self);
342
+ exponential_kernel<RNG>()(iter, lambda, gen);
343
+ return self;
344
+ }
345
+
346
+ // ==================================================== Cauchy ========================================================
347
+
348
+ template<template<typename> class cauchy_kernel, typename RNG>
349
+ Tensor& cauchy_impl_(Tensor& self, double median, double sigma, std::optional<Generator> gen) {
350
+ // TODO: instead of variable name 'sigma', use 'gamma' or 'scale'
351
+ // the variance, squared sigma, is undefined for cauchy distribution
352
+ TORCH_CHECK(sigma > 0.0, "cauchy_ expects sigma > 0.0, but found sigma=", sigma);
353
+ TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Cauchy distribution is a continuous probability distribution. dtype must be a floating point but you specified ", self.dtype());
354
+ CHECK_EMPTY_AND_RETURN(self);
355
+ auto iter = TensorIterator::borrowing_nullary_op(self);
356
+ cauchy_kernel<RNG>()(iter, median, sigma, gen);
357
+ return self;
358
+ }
359
+
360
+ // ==================================================== Bernoulli =====================================================
361
+
362
+ template<template<typename> class bernoulli_tensor_kernel, typename RNG>
363
+ Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
364
+ CHECK_EMPTY_AND_RETURN(self);
365
+ NoNamesGuard guard;
366
+ at::assert_no_internal_overlap(self);
367
+ bernoulli_tensor_kernel<RNG>()(self, p_, gen);
368
+ return self;
369
+ }
370
+
371
+ template<template<typename> class bernoulli_scalar_kernel, typename RNG>
372
+ Tensor& bernoulli_impl_(Tensor& self, double p, std::optional<Generator> gen) {
373
+ TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
374
+ CHECK_EMPTY_AND_RETURN(self);
375
+ at::assert_no_internal_overlap(self);
376
+ bernoulli_scalar_kernel<RNG>()(self, p, gen);
377
+ return self;
378
+ }
379
+
380
+ template<template<typename> class bernoulli_tensor_kernel, typename RNG>
381
+ Tensor& bernoulli_out_impl(Tensor& result, const Tensor& self, std::optional<Generator> gen) {
382
+ // result.resize_as_(self) requires self to have same dtype as result, so we
383
+ // use resize_ instead.
384
+ // TODO: Fix resize_as_. See pytorch/pytorch#11665.
385
+ result.resize_(self.sizes());
386
+ bernoulli_impl_<bernoulli_tensor_kernel, RNG>(result, self, gen);
387
+ namedinference::propagate_names(result, self);
388
+ return result;
389
+ }
390
+
391
+ #undef CHECK_OUT_OF_BOUNDS
392
+ #undef WARN_OUT_OF_BOUNDS
393
+
394
+ } // namespace at::native::templates
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distributions.h ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/Math.h>
4
+ #include <c10/macros/Macros.h>
5
+ #include <c10/util/MathConstants.h>
6
+
7
+ // ROCM hcc doesn't work well with using std:: in kernel functions
8
+ #if defined(__CUDA_ARCH__)
9
+ #include <c10/cuda/CUDAMathCompat.h>
10
+ #define compat_exp c10::cuda::compat::exp
11
+ #define compat_ceil c10::cuda::compat::ceil
12
+ #define compat_floor c10::cuda::compat::floor
13
+ #define compat_log c10::cuda::compat::log
14
+ #define compat_pow c10::cuda::compat::pow
15
+ #define compat_sqrt c10::cuda::compat::sqrt
16
+ #define compat_tan c10::cuda::compat::tan
17
+ #define compat_abs c10::cuda::compat::abs
18
+ #define compat_log1p c10::cuda::compat::log1p
19
+ #elif defined(__HIPCC__)
20
+ #include <c10/hip/HIPMathCompat.h>
21
+ #define compat_exp c10::hip::compat::exp
22
+ #define compat_ceil c10::hip::compat::ceil
23
+ #define compat_floor c10::hip::compat::floor
24
+ #define compat_log c10::hip::compat::log
25
+ #define compat_pow c10::hip::compat::pow
26
+ #define compat_sqrt c10::hip::compat::sqrt
27
+ #define compat_tan c10::hip::compat::tan
28
+ #define compat_abs c10::hip::compat::abs
29
+ #define compat_log1p c10::hip::compat::log1p
30
+ #else
31
+ #define compat_exp std::exp
32
+ #define compat_ceil std::ceil
33
+ #define compat_floor std::floor
34
+ #define compat_log std::log
35
+ #define compat_pow std::pow
36
+ #define compat_sqrt std::sqrt
37
+ #define compat_tan std::tan
38
+ #define compat_abs std::abs
39
+ #define compat_log1p std::log1p
40
+ #endif
41
+
42
+ namespace {
43
+
44
+ #if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
45
+ // we cannot use std::isnan directly due to some incompatibility of
46
+ // gcc constexpr'ing and nvcc
47
+ using std::isnan;
48
+ #endif
49
+
50
+ // Here sampler_t should be function type scalar_t(void). For gpu
51
+ // "sampler" is a device function, but since ROCM doesn't have
52
+ // equivalent to nvstd::function, we use a template type parameter to
53
+ // capture it.
54
+ template<typename scalar_t, typename sampler_t>
55
+ struct BaseSampler {
56
+ sampler_t sampler;
57
+ C10_DEVICE BaseSampler(const sampler_t& sampler): sampler(sampler) {}
58
+ C10_DEVICE scalar_t sample() {
59
+ return sampler();
60
+ }
61
+ };
62
+
63
+ // The function `sample_gamma` is
64
+ // is adapted from Numpy's distributions.c implementation.
65
+ // It is MIT licensed, so here is the copyright:
66
+
67
+ /* Copyright 2005 Robert Kern (robert.kern@gmail.com)
68
+ *
69
+ * Permission is hereby granted, free of charge, to any person obtaining a
70
+ * copy of this software and associated documentation files (the
71
+ * "Software"), to deal in the Software without restriction, including
72
+ * without limitation the rights to use, copy, modify, merge, publish,
73
+ * distribute, sublicense, and/or sell copies of the Software, and to
74
+ * permit persons to whom the Software is furnished to do so, subject to
75
+ * the following conditions:
76
+ *
77
+ * The above copyright notice and this permission notice shall be included
78
+ * in all copies or substantial portions of the Software.
79
+ *
80
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
81
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
82
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
83
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
84
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
85
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
86
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
87
+ */
88
+
89
+ template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t, typename normal_sampler_t>
90
+ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform, BaseSampler<accscalar_t, normal_sampler_t>& standard_normal) {
91
+ accscalar_t scale = 1.0f;
92
+
93
+ // Boost alpha for higher acceptance probability.
94
+ if (alpha < 1.0f) {
95
+ if (alpha == 0.f) return 0.f;
96
+ scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha);
97
+ alpha += 1.0f;
98
+ }
99
+
100
+ // This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
101
+ // doi:10.1145/358407.358414
102
+ const accscalar_t d = alpha - 1.0f / 3.0f;
103
+ const accscalar_t c = 1.0f / compat_sqrt(9.0f * d);
104
+ for (;;) {
105
+ accscalar_t x, y;
106
+ do {
107
+ x = standard_normal.sample();
108
+ y = 1.0f + c * x;
109
+ } while (y <= 0);
110
+ const accscalar_t v = y * y * y;
111
+ const accscalar_t u = 1 - standard_uniform.sample();
112
+ const accscalar_t xx = x * x;
113
+ if (u < 1.0f - 0.0331f * xx * xx)
114
+ return static_cast<scalar_t>(scale * d * v);
115
+ if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v)))
116
+ return static_cast<scalar_t>(scale * d * v);
117
+ }
118
+ }
119
+
120
+ /* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted
121
+ * from TensorFlow's random_binomial_op.cc implementation. That code is under
122
+ * copyright: 2019 The TensorFlow Authors.
123
+ *
124
+ * It was released under the Apache License, Version 2.0 (the "License"), available at:
125
+ * http://www.apache.org/licenses/LICENSE-2.0
126
+ */
127
+
128
+ template<typename scalar_t>
129
+ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
130
+ const static scalar_t kTailValues[] = {
131
+ 0.0810614667953272,
132
+ 0.0413406959554092,
133
+ 0.0276779256849983,
134
+ 0.02079067210376509,
135
+ 0.0166446911898211,
136
+ 0.0138761288230707,
137
+ 0.0118967099458917,
138
+ 0.0104112652619720,
139
+ 0.00925546218271273,
140
+ 0.00833056343336287
141
+ };
142
+ if (k <= 9) {
143
+ return kTailValues[static_cast<size_t>(k)];
144
+ }
145
+ scalar_t kp1sq = (k + 1) * (k + 1);
146
+ return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
147
+ }
148
+
149
+
150
+ template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
151
+ C10_DEVICE scalar_t binomial_inversion(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
152
+ accscalar_t U;
153
+ accscalar_t geom_sum = 0;
154
+ scalar_t num_geom = 0;
155
+
156
+ accscalar_t logprob = compat_log1p(-prob);
157
+
158
+ while (1) {
159
+ U = standard_uniform.sample();
160
+ accscalar_t geom = compat_ceil(compat_log(U) / logprob);
161
+ geom_sum += geom;
162
+ if (geom_sum > count) {
163
+ break;
164
+ }
165
+ num_geom = num_geom + 1;
166
+ }
167
+ return num_geom;
168
+ }
169
+
170
+ template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
171
+ C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
172
+ scalar_t k;
173
+ accscalar_t U, V, us;
174
+
175
+ // This is spq in the paper.
176
+ const accscalar_t stddev = compat_sqrt(count * prob * (1 - prob));
177
+
178
+ // Other coefficients for Transformed Rejection sampling.
179
+ const accscalar_t b = 1.15 + 2.53 * stddev;
180
+ const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob;
181
+ const accscalar_t c = count * prob + 0.5;
182
+ const accscalar_t v_r = 0.92 - 4.2 / b;
183
+ const accscalar_t r = prob / (1 - prob);
184
+
185
+ const accscalar_t alpha = (2.83 + 5.1 / b) * stddev;
186
+ const accscalar_t m = compat_floor((count + 1) * prob);
187
+
188
+ while (1) {
189
+ U = standard_uniform.sample() - 0.5;
190
+ V = standard_uniform.sample();
191
+
192
+ us = 0.5 - compat_abs(U);
193
+ k = static_cast<scalar_t>(compat_floor((2 * a / us + b) * U + c));
194
+
195
+ // Reject non-sensical answers.
196
+ if (k < 0 || k > count) {
197
+ continue;
198
+ }
199
+ // Region for which the box is tight, and we can return our calculated value.
200
+ // This should happen 0.86 * v_r times. In the limit as n * p is large,
201
+ // the acceptance rate converges to ~79% (and in the lower regime it is ~24%).
202
+ if (us >= 0.07 && V <= v_r) {
203
+ return k;
204
+ }
205
+
206
+ // This deviates from Hormann's BTRS algorithm, as there is a log missing.
207
+ // For all (u, v) pairs outside of the bounding box, this calculates the
208
+ // transformed-reject ratio.
209
+ V = compat_log(V * alpha / (a / (us * us) + b));
210
+ accscalar_t upperbound =
211
+ ((m + 0.5) * compat_log((m + 1) / (r * (count - m + 1))) +
212
+ (count + 1) * compat_log((count - m + 1) / (count - k + 1)) +
213
+ (k + 0.5) * compat_log(r * (count - k + 1) / (k + 1)) +
214
+ stirling_approx_tail<accscalar_t>(m) + stirling_approx_tail<accscalar_t>(count - m) -
215
+ stirling_approx_tail<accscalar_t>(k) - stirling_approx_tail<accscalar_t>(count - k));
216
+
217
+ if (V <= upperbound) {
218
+ return k;
219
+ }
220
+ }
221
+ }
222
+
223
+ template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
224
+ C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
225
+ if (count <= 0.0 || prob <= 0.0) {
226
+ return 0;
227
+ } else if (prob >= 1.0) {
228
+ return count;
229
+ } else if (prob <= 0.5) {
230
+ if (count * prob >= 10.0) {
231
+ // btrs
232
+ return btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
233
+ } else {
234
+ // binomial inversion
235
+ return binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
236
+ }
237
+ } else if (prob > 0.5) {
238
+ scalar_t qprob = 1.0 - prob;
239
+ if (count * qprob >= 10.0) {
240
+ // btrs
241
+ return count - btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
242
+ } else {
243
+ // count - binomial inversion
244
+ return count - binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
245
+ }
246
+ } else {
247
+ // prob is nan?
248
+ return static_cast<scalar_t>(NAN);
249
+ }
250
+ }
251
+
252
+ /*
253
+ * This function is derived from the implementation of the digamma function in the Cephes Math Library.
254
+ * See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
255
+ */
256
+ template<typename scalar_t, typename accscalar_t>
257
+ C10_DEVICE inline scalar_t digamma_one(scalar_t x) {
258
+ constexpr accscalar_t PSI_10 = 2.25175258906672110764;
259
+ if (x == 0) {
260
+ return INFINITY;
261
+ }
262
+ accscalar_t additional_summand = 0;
263
+ int x_is_integer = x == compat_floor(x);
264
+ if (x < 0) {
265
+ if (x_is_integer) {
266
+ return INFINITY;
267
+ }
268
+ // it is more standard to write this as recursion, but
269
+ // nvcc does not like that
270
+ additional_summand = -c10::pi<scalar_t> /
271
+ compat_tan(c10::pi<scalar_t> * x);
272
+ x = 1 - x;
273
+ }
274
+
275
+ // Push x to be >= 10
276
+ accscalar_t result = 0;
277
+ while (x < 10) {
278
+ result -= 1 / x;
279
+ x += 1;
280
+ }
281
+ if (x == 10) {
282
+ return result + PSI_10 + additional_summand;
283
+ }
284
+
285
+ // Compute asymptotic digamma
286
+ static const accscalar_t A[] = {
287
+ 8.33333333333333333333E-2,
288
+ -2.10927960927960927961E-2,
289
+ 7.57575757575757575758E-3,
290
+ -4.16666666666666666667E-3,
291
+ 3.96825396825396825397E-3,
292
+ -8.33333333333333333333E-3,
293
+ 8.33333333333333333333E-2,
294
+ };
295
+
296
+ accscalar_t y = 0;
297
+ if (x < 1.0e17f) {
298
+ accscalar_t z = 1.0 / (x * x);
299
+ y = z * polevl<accscalar_t>(z, A, 6);
300
+ }
301
+ return static_cast<scalar_t>(
302
+ result + compat_log(x) - (0.5f / x) - y + additional_summand);
303
+ }
304
+
305
+ // Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
306
+ // for random number x drawn from a standard Gamma distribution Gamma(alpha).
307
+ template <typename scalar_t, typename accscalar_t>
308
+ C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
309
+ // Use a Taylor series expansion for small x.
310
+ accscalar_t x = static_cast<accscalar_t>(x_);
311
+ accscalar_t alpha = static_cast<accscalar_t>(alpha_);
312
+ if (x < 0.8f) {
313
+ accscalar_t numer = 1;
314
+ accscalar_t denom = alpha;
315
+ auto series1 = numer / denom;
316
+ auto series2 = numer / (denom * denom);
317
+ for (int i = 1; i <= 5; ++i) {
318
+ numer *= -x / static_cast<accscalar_t>(i);
319
+ denom += 1;
320
+ series1 += numer / denom;
321
+ series2 += numer / (denom * denom);
322
+ }
323
+ const auto pow_x_alpha = compat_pow(x, alpha);
324
+ const auto gamma_pdf = compat_pow(x, alpha - 1) * compat_exp(-x);
325
+ const auto gamma_cdf = pow_x_alpha * series1;
326
+ const auto gamma_cdf_alpha =
327
+ (compat_log(x) - digamma_one<accscalar_t, accscalar_t>(alpha)) *
328
+ gamma_cdf -
329
+ pow_x_alpha * series2;
330
+ const auto result = -gamma_cdf_alpha / gamma_pdf;
331
+ return isnan(result) ? static_cast<scalar_t>( 0.f ) : static_cast<scalar_t>(result);
332
+ }
333
+
334
+ // Use a Rice saddle point expansion for large alpha.
335
+ if (alpha > 8.0f) {
336
+ if (0.9f * alpha <= x && x <= 1.1f * alpha) {
337
+ const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
338
+ const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
339
+ - 65 * x * x / alpha + alpha * (107 + 3600 * x);
340
+ const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
341
+ return static_cast<scalar_t>(numer_1 * numer_2 / denom);
342
+ }
343
+ const auto denom = compat_sqrt(8 * alpha);
344
+ const auto term2 = denom / (alpha - x);
345
+ const auto term3 = compat_pow(
346
+ x - alpha - alpha * compat_log(x / alpha),
347
+ static_cast<accscalar_t>(-1.5));
348
+ const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
349
+ const auto term1 = compat_log(x / alpha) * term23 -
350
+ compat_sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
351
+ const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
352
+ const auto numer = x * term1;
353
+ return static_cast<scalar_t>(-stirling * numer / denom);
354
+ }
355
+
356
+ // Use a bivariate rational approximation to the reparameterized gradient.
357
+ const auto u = compat_log(x / alpha);
358
+ const auto v = compat_log(alpha);
359
+ static const accscalar_t coef_uv[3][8] = {
360
+ {0.16009398, -0.094634809, 0.025146376, -0.0030648343,
361
+ 1, 0.32668115, 0.10406089, 0.0014179084},
362
+ {0.53487893, 0.1298071, 0.065735949, -0.0015649758,
363
+ 0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
364
+ {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
365
+ 0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
366
+ };
367
+ accscalar_t coef_v[8];
368
+ for (int i = 0; i < 8; ++ i) {
369
+ coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
370
+ }
371
+ const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
372
+ const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
373
+ return static_cast<scalar_t>(compat_exp(p / q));
374
+ }
375
+
376
+ // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
377
+ // Assumes x is close to zero and uses a Taylor expansion.
378
+ template <typename scalar_t, typename accscalar_t>
379
+ C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
380
+ const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha)
381
+ - digamma_one<scalar_t, accscalar_t>(alpha + beta) - compat_log(x);
382
+ scalar_t numer = 1;
383
+ scalar_t series = numer / alpha * (factor + 1 / alpha);
384
+ for (int i = 1; i <= 10; ++i) {
385
+ scalar_t casted_i = static_cast<scalar_t>(i);
386
+ numer *= (casted_i - beta) * x / casted_i;
387
+ const scalar_t denom = alpha + casted_i;
388
+ series += numer / denom * (factor + 1 / denom);
389
+ }
390
+ const scalar_t result = x * compat_pow(1 - x, -beta) * series;
391
+ return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
392
+ }
393
+
394
+ // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
395
+ // Assumes x is close to zero and uses a Taylor expansion.
396
+ template <typename scalar_t, typename accscalar_t>
397
+ C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
398
+ const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha + beta) - digamma_one<scalar_t, accscalar_t>(beta);
399
+ scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
400
+ for (int i = 1; i <= 8; ++i) {
401
+ scalar_t casted_i = static_cast<scalar_t>(i);
402
+ numer *= -x / casted_i;
403
+ dbetas = dbetas * (beta - casted_i) + betas;
404
+ betas = betas * (beta - casted_i);
405
+ series += numer / (alpha + casted_i) * (dbetas + factor * betas);
406
+ }
407
+ const scalar_t result = -compat_pow(1 - x, 1 - beta) * series;
408
+ return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
409
+ }
410
+
411
+ // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
412
+ // Assumes alpha and beta are both large and uses a Rice saddle point expansion.
413
+ // To ensure numerical stability, this computation is performed at higher precision.
414
+ template<typename scalar_t, typename accscalar_t>
415
+ C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
416
+ const accscalar_t total = alpha + beta;
417
+ const accscalar_t mean = alpha / total;
418
+ const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
419
+ if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) {
420
+ // Avoid the singularity at x = mean.
421
+ const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * (
422
+ (43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
423
+ 3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
424
+ (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
425
+ 8 * (1 - x) * (135 * beta - 11)))));
426
+ const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total);
427
+ const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total);
428
+ return prefactor_num / (1 - x) * poly / prefactor_den;
429
+ }
430
+ const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total);
431
+ const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
432
+ * (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
433
+ / (1 + 1 / (12 * total) + 1 / (288 * total * total));
434
+ const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
435
+ const accscalar_t axbx = alpha * (x - 1) + beta * x;
436
+ const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast<accscalar_t>(1.5f)) * axbx * axbx;
437
+ const accscalar_t term1 = term1_num / term1_den;
438
+ const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x));
439
+ const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total);
440
+ const accscalar_t term3_den = beta * x + alpha * (x - 1);
441
+ const accscalar_t term3 = term3_num / term3_den;
442
+ const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) +
443
+ alpha * compat_log(alpha / (total * x));
444
+ const accscalar_t term4 = compat_pow(term4_base, static_cast<accscalar_t>(-1.5f));
445
+ const accscalar_t term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4));
446
+ return static_cast<scalar_t>(stirling * prefactor * term1234);
447
+ }
448
+
449
+ // Computes a scaled reparameterized gradient
450
+ // -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x)
451
+ // for random number x drawn from a Beta distribution Beta(alpha,beta).
452
+ // This function inputs total=alpha+beta to make it easy to implement
453
+ // Dirichlet reparameterized gradients in terms of Betas.
454
+ template<typename scalar_t, typename accscalar_t>
455
+ C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
456
+ accscalar_t x_ = static_cast<accscalar_t>(x);
457
+ accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
458
+ accscalar_t total_ = static_cast<accscalar_t>(total);
459
+
460
+ const scalar_t beta = total - alpha;
461
+ const accscalar_t beta_ = total_ - alpha_;
462
+ const scalar_t boundary = total * x * (1 - x);
463
+
464
+ // Use an asymptotic approximation for x close to 0.
465
+ if (x <= 0.5f && boundary < 2.5f) {
466
+ return _beta_grad_alpha_small<scalar_t, accscalar_t>(x, alpha, beta);
467
+ }
468
+
469
+ // Use an asymptotic approximation for x close to 1.
470
+ if (x >= 0.5f && boundary < 0.75f) {
471
+ return -_beta_grad_beta_small<scalar_t, accscalar_t>(1 - x, beta, alpha);
472
+ }
473
+
474
+ // Use an asymptotic approximation when alpha and (total - alpha) are both large.
475
+ if (alpha > 6 && beta > 6) {
476
+ return _beta_grad_alpha_mid<scalar_t, accscalar_t>(x_, alpha_, beta_);
477
+ }
478
+
479
+ // Use a rational correction to an analytic approximation.
480
+ static const accscalar_t c[2][3][3][4] = {
481
+ {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
482
+ {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
483
+ {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
484
+ {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
485
+ {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
486
+ {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
487
+ {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
488
+ {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
489
+ {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
490
+ {{{1, -0.02924021934, -0.04438342661, 0.007285809825},
491
+ {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
492
+ {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
493
+ {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
494
+ {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
495
+ {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
496
+ {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
497
+ {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
498
+ {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
499
+ };
500
+ const accscalar_t u = compat_log(x_);
501
+ const accscalar_t a = compat_log(alpha_) - u;
502
+ const accscalar_t b = compat_log(total_) - a;
503
+ const accscalar_t pow_u[3] = {1, u, u * u};
504
+ const accscalar_t pow_a[3] = {1, a, a * a};
505
+ accscalar_t p = 0.0;
506
+ accscalar_t q = 0.0;
507
+ for (int i = 0; i < 3; ++i) {
508
+ for (int j = 0; j < 3; ++j) {
509
+ const accscalar_t ua = pow_u[i] * pow_a[j];
510
+ p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3])));
511
+ q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3])));
512
+ }
513
+ }
514
+ const accscalar_t approx = x_ * (digamma_one<scalar_t, accscalar_t>(total_) - digamma_one<scalar_t, accscalar_t>(alpha_)) / beta_;
515
+ return static_cast<scalar_t>(p / q * approx);
516
+ }
517
+
518
+ } // namespace
.venv/lib/python3.11/site-packages/torch/include/ATen/native/EmbeddingBag.h ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+ #include <ATen/Config.h>
3
+ #include <cstdint>
4
+
5
+ #ifdef USE_FBGEMM
6
+ #include <fbgemm/FbgemmEmbedding.h>
7
+ #endif
8
+
9
+ namespace at::native {
10
+
11
+ enum class EmbeddingBagMode {
12
+ SUM = 0,
13
+ MEAN = 1,
14
+ MAX = 2,
15
+ };
16
+
17
+ [[maybe_unused]] static bool operator==(int64_t op1, EmbeddingBagMode op2) {
18
+ return op1 == static_cast<int64_t>(op2);
19
+ }
20
+
21
+ [[maybe_unused]] static bool operator!=(int64_t op1, EmbeddingBagMode op2) {
22
+ return !(op1 == op2);
23
+ }
24
+
25
+ void check_arguments(
26
+ const Tensor& weight,
27
+ const Tensor& indices,
28
+ const Tensor& offsets,
29
+ const int64_t mode,
30
+ const std::optional<Tensor>& per_sample_weights,
31
+ bool include_last_offset);
32
+
33
+ void make_bag_size_out(
34
+ Tensor& bag_size_out,
35
+ const Tensor& offsets,
36
+ const Tensor& indices,
37
+ const int64_t mode,
38
+ const bool include_last_offset,
39
+ const bool requires_grad);
40
+
41
+ void make_max_indices_out(
42
+ Tensor& max_indices_out,
43
+ const Tensor& weight,
44
+ const Tensor& indices,
45
+ const Tensor& offsets,
46
+ const Tensor& bag_size,
47
+ const int64_t mode,
48
+ bool include_last_offset);
49
+
50
+ void make_offset2bag_out(
51
+ Tensor& offset2bag,
52
+ Tensor& output,
53
+ const Tensor& weight,
54
+ const Tensor& indices,
55
+ const Tensor& offsets,
56
+ const int64_t mode,
57
+ const std::optional<Tensor>& per_sample_weights,
58
+ const int64_t padding_idx = -1);
59
+
60
+ #ifdef USE_FBGEMM
61
+
62
+ template<bool has_weight, typename TIndex, typename TData>
63
+ struct _CallbackAndBlockSize {
64
+ using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;
65
+
66
+ int64_t blockSize = -1;
67
+ TCallback callback = nullptr;
68
+
69
+ static TCallback generateCallback(int64_t block_size) {
70
+ return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
71
+ block_size,
72
+ has_weight,
73
+ /* normalize_by_lengths */false,
74
+ /* prefetch */16,
75
+ /* is_weight_positional */false,
76
+ /* use_offsets */true);
77
+ }
78
+
79
+ _CallbackAndBlockSize() = default;
80
+
81
+ explicit _CallbackAndBlockSize(std::optional<int64_t> maybe_block_size)
82
+ : blockSize(maybe_block_size.value_or(-1))
83
+ , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
84
+ {}
85
+ };
86
+
87
+ template<typename... StorageMixins>
88
+ struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
89
+
90
+ _EmbeddingBagKernelCacheImpl() = default;
91
+ // use each of the mixins to store corresponding kernel and block size
92
+ explicit _EmbeddingBagKernelCacheImpl(std::optional<int64_t> maybe_block_size)
93
+ : StorageMixins(maybe_block_size)...
94
+ {}
95
+
96
+ // this method is thread safe (call sites may call from different threads)
97
+ template<bool has_weight, typename TIndex, typename TData>
98
+ typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
99
+ getCallback(int64_t block_size) const {
100
+ // if the cache doesn't store the kernel for the incoming block size
101
+ // (so it is different from the one stored in corresponding mixin)
102
+ // regenerate the kernel (not writing it into the cache so we avoid locks)
103
+ if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
104
+ return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
105
+ }
106
+ // else retrieve the cached kernel from the corresponding mixin
107
+ return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
108
+ }
109
+ };
110
+
111
+ // instantiate the cache with the list of storage mixins
112
+ // for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
113
+ using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
114
+ _CallbackAndBlockSize<true, int32_t, float>,
115
+ _CallbackAndBlockSize<false, int32_t, float>,
116
+ _CallbackAndBlockSize<true, int64_t, float>,
117
+ _CallbackAndBlockSize<false, int64_t, float>,
118
+ _CallbackAndBlockSize<true, int32_t, unsigned short>,
119
+ _CallbackAndBlockSize<false, int32_t, unsigned short>,
120
+ _CallbackAndBlockSize<true, int64_t, unsigned short>,
121
+ _CallbackAndBlockSize<false, int64_t, unsigned short>>;
122
+ #else
123
+ struct _EmbeddingBagKernelCache {
124
+ explicit _EmbeddingBagKernelCache(std::optional<int64_t> /* maybe_block_size */) {}
125
+ };
126
+ #endif
127
+
128
+ void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
129
+ Tensor& bag_size, Tensor* max_indices,
130
+ const Tensor &weight, const Tensor &indices,
131
+ const Tensor &offsets, const int64_t mode = 0,
132
+ const std::optional<Tensor>& per_sample_weights = std::nullopt,
133
+ bool include_last_offset = false,
134
+ int64_t padding_idx = -1,
135
+ _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
136
+
137
+ void _embedding_bag_cpu_out(
138
+ at::Tensor& output,
139
+ at::Tensor& offset2bag,
140
+ at::Tensor& bag_size,
141
+ at::Tensor* p_max_indices,
142
+ const at::Tensor& weight,
143
+ const at::Tensor& indices,
144
+ const at::Tensor& offsets,
145
+ const bool scale_grad_by_freq,
146
+ const int64_t mode,
147
+ const bool sparse,
148
+ const std::optional<at::Tensor>& per_sample_weights,
149
+ const bool include_last_offset,
150
+ const std::optional<int64_t>& padding_idx,
151
+ _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
152
+
153
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ForeachUtils.h ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Device.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/ScalarType.h>
6
+ #include <ATen/core/Tensor.h>
7
+ #include <ATen/native/utils/ParamsHash.h>
8
+ #include <c10/util/Exception.h>
9
+ #include <c10/util/irange.h>
10
+
11
+ #ifndef AT_PER_OPERATOR_HEADERS
12
+ #include <ATen/NativeFunctions.h>
13
+ #else
14
+ #include <ATen/ops/result_type_native.h>
15
+ #endif
16
+
17
+ #include <unordered_map>
18
+ #include <vector>
19
+
20
+ namespace at::native {
21
+ namespace {
22
+ // Check if tensor list has either a boolean tensor or a integer tensor
23
+ inline bool has_integral_tensor(TensorList tensors, const bool includeBool) {
24
+ return std::any_of(
25
+ tensors.begin(), tensors.end(), [&includeBool](const auto& t) {
26
+ return at::isIntegralType(t.scalar_type(), includeBool);
27
+ });
28
+ }
29
+ // check if tensor list has bool tensors
30
+ inline bool has_bool_tensor(TensorList tensors) {
31
+ return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool {
32
+ return t.scalar_type() == ScalarType::Bool;
33
+ });
34
+ }
35
+
36
+ // Check foreach API restrictions
37
+ // - Tensor lists must be non-empty.
38
+ // - All TensorLists and ScalarLists must have the same number of elements.
39
+ // - Corresponding tensors must have the same size.
40
+ inline void check_foreach_api_restrictions(TensorList tensors) {
41
+ TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
42
+ }
43
+
44
+ inline void check_foreach_api_restrictions(
45
+ TensorList tensors,
46
+ ArrayRef<Scalar> scalars) {
47
+ check_foreach_api_restrictions(tensors);
48
+ TORCH_CHECK(
49
+ tensors.size() == scalars.size(),
50
+ "Tensor list must have same number of elements as scalar list.");
51
+ }
52
+
53
+ inline void check_foreach_api_restrictions(
54
+ TensorList tensors1,
55
+ TensorList tensors2) {
56
+ TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
57
+ TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
58
+ TORCH_CHECK(
59
+ tensors1.size() == tensors2.size(),
60
+ "Tensor lists must have the same number of tensors, got ",
61
+ tensors1.size(),
62
+ " and ",
63
+ tensors2.size());
64
+ }
65
+
66
+ inline void check_foreach_api_restrictions(
67
+ TensorList tensors1,
68
+ TensorList tensors2,
69
+ TensorList tensors3) {
70
+ TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
71
+ TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
72
+ TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor.");
73
+ TORCH_CHECK(
74
+ tensors1.size() == tensors2.size(),
75
+ "Tensor lists must have the same number of tensors, got ",
76
+ tensors1.size(),
77
+ " and ",
78
+ tensors2.size());
79
+ TORCH_CHECK(
80
+ tensors1.size() == tensors3.size(),
81
+ "Tensor lists must have the same number of tensors, got ",
82
+ tensors1.size(),
83
+ " and ",
84
+ tensors3.size());
85
+ }
86
+
87
+ inline void check_foreach_api_restrictions(
88
+ TensorList tensors1,
89
+ TensorList tensors2,
90
+ TensorList tensors3,
91
+ ArrayRef<Scalar> scalars) {
92
+ check_foreach_api_restrictions(tensors1, tensors2, tensors3);
93
+ TORCH_CHECK(
94
+ tensors1.size() == scalars.size(),
95
+ "Tensor list must have same number of elements as scalar list, got ",
96
+ tensors1.size(),
97
+ " and ",
98
+ scalars.size());
99
+ }
100
+
101
+ // Helper function called in check_fast_path_restrictions to check whether all
102
+ // corresponding tensors (aligning in index across the tensorLists) share the
103
+ // same device and dtype.
104
+ inline bool _check_tensors_share_device_and_dtype(
105
+ ArrayRef<TensorList> tensorLists,
106
+ const bool skip_dtype_check = false) {
107
+ const auto expected_dtype = tensorLists[0][0].dtype();
108
+ const auto expected_device = tensorLists[0][0].device();
109
+
110
+ auto is_tensor_okay = [&](const Tensor& tensor) {
111
+ return (skip_dtype_check || tensor.dtype() == expected_dtype) &&
112
+ tensor.device() == expected_device && tensor.layout() == at::kStrided &&
113
+ tensor.is_non_overlapping_and_dense();
114
+ };
115
+
116
+ for (const auto& tensorList : tensorLists) {
117
+ for (const auto& tensor : tensorList) {
118
+ if (!is_tensor_okay(tensor)) {
119
+ return false;
120
+ }
121
+ }
122
+ }
123
+
124
+ return true;
125
+ }
126
+
127
+ // Helper function called in check_fast_path_restrictions to check if
128
+ // corresponding tensors in tensor lists have the same sizes and strides.
129
+ inline bool _check_tensors_share_sizes_and_strides(
130
+ ArrayRef<TensorList> tensorLists) {
131
+ auto is_diff_stride = [](const IntArrayRef& size,
132
+ const IntArrayRef& left_stride,
133
+ const IntArrayRef& right_stride) -> bool {
134
+ const size_t size_size = size.size();
135
+ for (const auto dim : c10::irange(size_size)) {
136
+ if (size[dim] == 1)
137
+ continue;
138
+ if (left_stride[dim] != right_stride[dim]) {
139
+ return true;
140
+ }
141
+ }
142
+ return false;
143
+ };
144
+ for (const auto i : c10::irange(1, tensorLists.size())) {
145
+ for (const auto j : c10::irange(tensorLists[0].size())) {
146
+ if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() ||
147
+ is_diff_stride(
148
+ tensorLists[0][j].sizes(),
149
+ tensorLists[0][j].strides(),
150
+ tensorLists[i][j].strides())) {
151
+ return false;
152
+ }
153
+ }
154
+ }
155
+
156
+ return true;
157
+ }
158
+
159
+ // Helper function called in check_fast_path_restrictions to check whether
160
+ // all tensors type promote properly with the scalars in scalarList. This
161
+ // function assumes that _check_tensors_share_device_and_dtype has already been
162
+ // called so that all corresponding tensors in tensorLists have the same dtype.
163
+ // Then, it is sufficient to check the type promotion with just one tensorList.
164
+ inline bool _check_tensors_do_type_promotion_with_scalars(
165
+ TensorList tensorList,
166
+ ArrayRef<Scalar> scalarList = {},
167
+ bool does_op_promote_integer_inputs_to_float = false) {
168
+ for (const auto i : c10::irange(tensorList.size())) {
169
+ // For division, integer inputs will result in float.
170
+ if (does_op_promote_integer_inputs_to_float) {
171
+ if (at::isIntegralType(
172
+ tensorList[i].scalar_type(), /*includeBool*/ true)) {
173
+ return false;
174
+ }
175
+ }
176
+ if (!scalarList.empty()) {
177
+ const auto& scalar =
178
+ scalarList.size() == 1 ? scalarList[0] : scalarList[i];
179
+ const auto& tensor = tensorList[i];
180
+ // note(mkozuki): This check might be responsible for
181
+ // `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path.
182
+ if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
183
+ return false;
184
+ }
185
+ }
186
+ }
187
+
188
+ return true;
189
+ }
190
+
191
+ // To go via 'fast' path, several conditions must be satisfied
192
+ // - All tensors in all lists must have the same dtype.
193
+ // - All tensors must be on the same device
194
+ // - All tensors must have strided layout
195
+ // - All tensors must be non-overlapping and dense
196
+ // - Resulting tensor must have the same dtype as the input one
197
+
198
+ // [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
199
+ // ``does_op_promote_integer_inputs_to_float=true`` means that the result of
200
+ // the op will be float even if inputs are integer or boolean, which
201
+ // currently fast path does not support. In short, this flag, when
202
+ // turned on, gatekeeps the op from going down the fastpath.
203
+
204
+ // Please, make sure to call check_foreach_api_restrictions before calling this
205
+ // method. There is a set of preconditions that have to be satisfied.
206
+ inline bool check_fast_path_restrictions(
207
+ ArrayRef<TensorList> tensorLists,
208
+ ArrayRef<Scalar> scalarList = {},
209
+ bool does_op_promote_integer_inputs_to_float = false) {
210
+ return _check_tensors_share_device_and_dtype(tensorLists) &&
211
+ _check_tensors_share_sizes_and_strides(tensorLists) &&
212
+ _check_tensors_do_type_promotion_with_scalars(
213
+ tensorLists[0],
214
+ scalarList,
215
+ does_op_promote_integer_inputs_to_float);
216
+ }
217
+
218
+ inline std::vector<c10::Scalar> convert_tensor_to_scalar_list(
219
+ const Tensor& scalarList_,
220
+ int64_t expect_length) {
221
+ std::vector<c10::Scalar> scalarList;
222
+ TORCH_CHECK(
223
+ scalarList_.device() == c10::kCPU,
224
+ "Expected scalars to be on CPU, got ",
225
+ scalarList_.device(),
226
+ " instead.");
227
+ TORCH_CHECK(
228
+ scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
229
+ TORCH_CHECK(
230
+ scalarList_.dim() == 1,
231
+ "Expected packed scalar Tensor to be of dimension 1. Got ",
232
+ scalarList_.dim(),
233
+ " instead.");
234
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
235
+ kComplexHalf,
236
+ kHalf,
237
+ kBool,
238
+ kBFloat16,
239
+ scalarList_.scalar_type(),
240
+ "convert_tensor_to_scalar_list",
241
+ [&]() {
242
+ const scalar_t* scalar_data = scalarList_.const_data_ptr<scalar_t>();
243
+ TORCH_CHECK(
244
+ (expect_length == scalarList_.size(0)),
245
+ "Expected length of scalars to match input of length ",
246
+ expect_length,
247
+ " but got ",
248
+ scalarList_.size(0),
249
+ " instead.");
250
+ for (int64_t i = 0; i < scalarList_.size(0); i++) {
251
+ scalarList.emplace_back(scalar_data[i]);
252
+ }
253
+ });
254
+ return scalarList;
255
+ }
256
+
257
+ // see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
258
+ inline bool can_use_fast_route(
259
+ ArrayRef<TensorList> tensorLists,
260
+ ArrayRef<Scalar> scalarList = {},
261
+ bool does_op_promote_integer_inputs_to_float = false) {
262
+ return check_fast_path_restrictions(
263
+ tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
264
+ }
265
+
266
+ // see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
267
+ inline bool can_use_fast_route(
268
+ TensorList tensors1,
269
+ TensorList tensors2,
270
+ bool does_op_promote_integer_inputs_to_float = false) {
271
+ return can_use_fast_route(
272
+ {tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
273
+ }
274
+
275
+ using DeviceDtypeKey = std::pair<at::Device, at::ScalarType>;
276
+ using IndicesT = std::vector<size_t>;
277
+ using nested_optional_tensorvec_t =
278
+ std::vector<std::vector<std::optional<at::Tensor>>>;
279
+ using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
280
+ using FlatMap = std::unordered_map<
281
+ DeviceDtypeKey,
282
+ TensorsAndIndicesT,
283
+ ParamsHash<DeviceDtypeKey>>;
284
+
285
+ inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
286
+ const nested_optional_tensorvec_t& nested_tensorlist,
287
+ const bool with_indices) {
288
+ FlatMap grouped_tensors_with_indices;
289
+
290
+ TORCH_CHECK(!nested_tensorlist.empty());
291
+ TORCH_CHECK(!nested_tensorlist[0].empty());
292
+ const auto num_lists = nested_tensorlist.size();
293
+ const auto num_tensors = nested_tensorlist[0].size();
294
+
295
+ TORCH_CHECK(std::all_of(
296
+ nested_tensorlist.cbegin(),
297
+ nested_tensorlist.cend(),
298
+ [&](const auto& tensorlist) -> bool {
299
+ // note(crcrpar): Allow empty tensorlists following
300
+ // ref:
301
+ // https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24
302
+ return tensorlist.size() == num_tensors || tensorlist.size() == 0;
303
+ }));
304
+
305
+ for (const auto& tensor_index : c10::irange(num_tensors)) {
306
+ const auto key = [&]() -> DeviceDtypeKey {
307
+ const auto t = nested_tensorlist[0][tensor_index];
308
+ TORCH_CHECK(
309
+ t.has_value(),
310
+ "Tensors of the first list of nested Tensor lists are supposed to be defined but ",
311
+ "the ",
312
+ tensor_index,
313
+ "-th Tensor is not.");
314
+ return {t->device(), t->scalar_type()};
315
+ }();
316
+ TORCH_CHECK(
317
+ std::all_of(
318
+ nested_tensorlist.cbegin(),
319
+ nested_tensorlist.cend(),
320
+ [&](const auto& tensorlist) -> bool {
321
+ if (tensorlist.size() == 0) {
322
+ return true;
323
+ }
324
+ const auto& tensor = tensorlist[tensor_index];
325
+ // note(crcrpar): Currently the scope of this function is
326
+ // optimizers so there could be `state_steps` and other scalars
327
+ // whose elements are float tensors no matter what the parameter's
328
+ // dtype is.
329
+ if (!tensor.has_value()) {
330
+ return true;
331
+ } else {
332
+ const auto s = tensor->scalar_type();
333
+ const auto d = tensor->device();
334
+ // Note: `step` or `state_step` is float32 by default.
335
+ if (key.first == d) {
336
+ return key.second == s || s == at::ScalarType::Float ||
337
+ s == at::ScalarType::Double;
338
+ } else if (d.is_cpu()) {
339
+ // note(crcrpar): There are some test cases (e.g.
340
+ // TestOptim::test_adam) where state_steps are on CPU and the
341
+ // others are on CUDA. Currently a state_step Tensor has the
342
+ // dtype of float.
343
+ return s == at::ScalarType::Float ||
344
+ s == at::ScalarType::Double;
345
+ } else {
346
+ return false;
347
+ }
348
+ }
349
+ }),
350
+ "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");
351
+ if (!grouped_tensors_with_indices.count(key)) {
352
+ grouped_tensors_with_indices.insert(
353
+ {key,
354
+ TensorsAndIndicesT{
355
+ [&]() -> nested_optional_tensorvec_t {
356
+ nested_optional_tensorvec_t nested_tensorvec;
357
+ nested_tensorvec.reserve(num_lists);
358
+ for (const auto& i : c10::irange(num_lists)) {
359
+ std::vector<std::optional<at::Tensor>> tensors;
360
+ if (!nested_tensorlist[i].empty()) {
361
+ // NB: num_tensors is the max possible length for any of
362
+ // the inner lists of tensor references. Reserving the max
363
+ // trades memory for perf. This should not have significant
364
+ // impact.
365
+ tensors.reserve(num_tensors);
366
+ }
367
+ nested_tensorvec.emplace_back(tensors);
368
+ }
369
+ return nested_tensorvec;
370
+ }(),
371
+ [&]() -> IndicesT {
372
+ if (!with_indices) {
373
+ return {};
374
+ } else {
375
+ IndicesT indices;
376
+ indices.reserve(num_tensors);
377
+ return indices;
378
+ }
379
+ }()}});
380
+ }
381
+ for (const auto& list_index : c10::irange(num_lists)) {
382
+ if (!nested_tensorlist[list_index].empty()) {
383
+ grouped_tensors_with_indices[key].first[list_index].emplace_back(
384
+ nested_tensorlist[list_index][tensor_index]);
385
+ }
386
+ }
387
+ if (with_indices) {
388
+ grouped_tensors_with_indices[key].second.emplace_back(tensor_index);
389
+ }
390
+ }
391
+
392
+ return grouped_tensors_with_indices;
393
+ }
394
+
395
+ } // namespace
396
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedAdagrad.h ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+ #include <ATen/native/DispatchStub.h>
3
+
4
+ namespace at::native {
5
+
6
+ using fused_adagrad_fn = void (*)(
7
+ const at::Tensor& param,
8
+ const at::Tensor& grad,
9
+ const at::Tensor& state_sum,
10
+ const at::Tensor& state_step,
11
+ const double lr,
12
+ const double lr_decay,
13
+ const double weight_decay,
14
+ const double eps,
15
+ const bool maximize,
16
+ const float* grad_scale_ptr);
17
+
18
+ DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub);
19
+
20
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedAdam.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+ #include <ATen/native/DispatchStub.h>
3
+
4
+ namespace at::native {
5
+
6
+ enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };
7
+
8
+ using fused_adam_fn = void (*)(
9
+ const at::Tensor& param,
10
+ const at::Tensor& grad,
11
+ const at::Tensor& exp_avg,
12
+ const at::Tensor& exp_avg_sq,
13
+ const at::Tensor& max_exp_avg_sq,
14
+ const at::Tensor& state_step,
15
+ const double lr,
16
+ const double beta1,
17
+ const double beta2,
18
+ const double weight_decay,
19
+ const double eps,
20
+ const bool amsgrad,
21
+ const bool maximize,
22
+ const float* grad_scale_ptr,
23
+ const ADAM_MODE);
24
+
25
+ DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub);
26
+
27
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedSGD.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+ #include <ATen/native/DispatchStub.h>
3
+
4
+ namespace at::native {
5
+
6
+ using fused_sgd_fn = void (*)(
7
+ const at::Tensor& param,
8
+ const at::Tensor& grad,
9
+ const at::Tensor& momentum_buffer,
10
+ const double weight_decay,
11
+ const double momentum,
12
+ const double lr,
13
+ const double dampening,
14
+ const bool nesterov,
15
+ const bool maximize,
16
+ const bool is_first_step,
17
+ const float* grad_scale_ptr);
18
+
19
+ DECLARE_DISPATCH(fused_sgd_fn, fused_sgd_stub);
20
+
21
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSamplerUtils.h ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // See NOTE: [Tensor vs. TensorBase]
4
+ // https://github.com/pytorch/pytorch/pull/66979
5
+ #include <ATen/core/TensorBase.h>
6
+ #include <ATen/native/TensorProperties.h>
7
+ #include <ATen/native/CanUse32BitIndexMath.h>
8
+
9
+ namespace at::native {
10
+
11
+ namespace detail {
12
+
13
+ enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic};
14
+ enum class GridSamplerPadding {Zeros, Border, Reflection};
15
+
16
+ } // namespace detail
17
+
18
+ using detail::GridSamplerInterpolation;
19
+ using detail::GridSamplerPadding;
20
+
21
+ // See NOTE [ grid_sampler Native Functions ].
22
+ inline void check_grid_sampler_common(
23
+ const TensorBase& input,
24
+ const TensorBase& grid
25
+ ) {
26
+ auto input_opt = input.options();
27
+ auto grid_opt = grid.options();
28
+
29
+ TORCH_CHECK(
30
+ input.defined(),
31
+ "grid_sampler(): expected input to not be undefined");
32
+ TORCH_CHECK(
33
+ grid.defined(),
34
+ "grid_sampler(): expected grid to not be undefined");
35
+ TORCH_CHECK(
36
+ input_opt.device() == grid_opt.device(),
37
+ "grid_sampler(): expected input and grid to be on same device, but input "
38
+ "is on ", input_opt.device(), " and grid is on ", grid_opt.device());
39
+ TORCH_CHECK(
40
+ input_opt.layout() == kStrided && grid_opt.layout() == kStrided,
41
+ "grid_sampler(): expected input and grid to have torch.strided layout, but "
42
+ "input has ", input_opt.layout(), " and grid has ", grid_opt.layout());
43
+ TORCH_CHECK(
44
+ input.size(0) == grid.size(0),
45
+ "grid_sampler(): expected grid and input to have same batch size, but got "
46
+ "input with sizes ", input.sizes(), " and grid with sizes ", grid.sizes());
47
+ TORCH_CHECK(
48
+ grid.size(-1) == input.dim() - 2,
49
+ "grid_sampler(): expected grid to have size ", input.dim() - 2, " in last "
50
+ "dimension, but got grid with sizes ", grid.sizes());
51
+
52
+ for (const auto i : c10::irange(2, input.dim())) {
53
+ TORCH_CHECK(input.size(i) > 0,
54
+ "grid_sampler(): expected input to have non-empty spatial dimensions, "
55
+ "but input has sizes ", input.sizes(), " with dimension ", i, " being "
56
+ "empty");
57
+ }
58
+ }
59
+
60
+ // See NOTE [ grid_sampler Native Functions ].
61
+ inline void check_grid_sampler_2d(
62
+ const TensorBase& input,
63
+ const TensorBase& grid
64
+ ) {
65
+ TORCH_CHECK(
66
+ input.dim() == 4 && input.dim() == grid.dim(),
67
+ "grid_sampler(): expected 4D input and grid with same number of "
68
+ "dimensions, but got input with sizes ", input.sizes(),
69
+ " and grid with sizes ", grid.sizes());
70
+ }
71
+
72
+ // See NOTE [ grid_sampler Native Functions ].
73
+ inline void check_grid_sampler_3d(
74
+ const TensorBase& input,
75
+ const TensorBase& grid,
76
+ int64_t interpolation_mode
77
+ ) {
78
+ TORCH_CHECK(
79
+ input.dim() == 5 && input.dim() == grid.dim(),
80
+ "grid_sampler(): expected 5D input and grid with same number of "
81
+ "dimensions, but got input with sizes ", input.sizes(),
82
+ " and grid with sizes ", grid.sizes());
83
+ TORCH_CHECK(
84
+ !(input.dim() == 5 &&
85
+ static_cast<GridSamplerInterpolation>(interpolation_mode) ==
86
+ GridSamplerInterpolation::Bicubic),
87
+ "grid_sampler(): bicubic interpolation only supports 4D input");
88
+ }
89
+
90
+ // See NOTE [ grid_sampler Native Functions ].
91
+ // cudnn does not support inputs larger than 1024.
92
+ inline bool cond_cudnn_grid_sampler(
93
+ const TensorBase& input,
94
+ const TensorBase& grid
95
+ ) {
96
+ return (
97
+ at::native::cudnn_is_acceptable(input) &&
98
+ at::native::cudnn_is_acceptable(grid) &&
99
+ at::native::canUse32BitIndexMath(input) &&
100
+ at::native::canUse32BitIndexMath(grid) &&
101
+ input.dim() == 4 &&
102
+ input.sym_size(1) <= 1024);
103
+ }
104
+
105
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Histogram.h ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+
6
+ namespace at::native {
7
+
8
+ using histogramdd_fn = void(*)(const Tensor&, const std::optional<Tensor>&, bool, Tensor&, const TensorList&);
9
+ using histogramdd_linear_fn = void(*)(const Tensor&, const std::optional<Tensor>&, bool, Tensor&, const TensorList&, bool);
10
+ using histogram_select_outer_bin_edges_fn = void(*)(const Tensor& input, const int64_t N, std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges);
11
+
12
+ DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub);
13
+ DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub);
14
+ DECLARE_DISPATCH(histogram_select_outer_bin_edges_fn, histogram_select_outer_bin_edges_stub);
15
+
16
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexKernel.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/DispatchStub.h>
3
+ #include <c10/util/ArrayRef.h>
4
+
5
+ namespace at {
6
+ class Tensor;
7
+ class TensorBase;
8
+ struct TensorIterator;
9
+ struct TensorIteratorBase;
10
+ }
11
+
12
+ namespace c10 {
13
+ class Scalar;
14
+ }
15
+
16
+ namespace at::native {
17
+
18
+ using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
19
+ using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
20
+ using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
21
+ using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
22
+ using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate);
23
+ using take_fn = void(*)(TensorIterator & iter, const TensorBase& input);
24
+ using flip_fn = void(*)(TensorIterator &, const bool);
25
+ using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
26
+ using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);
27
+ using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &);
28
+
29
+ DECLARE_DISPATCH(index_fn, index_stub);
30
+ DECLARE_DISPATCH(index_fill_fn, index_fill_stub);
31
+ DECLARE_DISPATCH(index_copy_fn, index_copy_stub);
32
+ DECLARE_DISPATCH(index_put_fn, index_put_stub);
33
+ DECLARE_DISPATCH(put_fn, put_stub);
34
+ DECLARE_DISPATCH(take_fn, take_stub);
35
+ DECLARE_DISPATCH(flip_fn, flip_stub);
36
+ DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
37
+ DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub);
38
+ DECLARE_DISPATCH(masked_select_fn, masked_select_stub);
39
+ DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub);
40
+
41
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexingUtils.h ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/ExpandUtils.h>
3
+ #include <ATen/native/CanUse32BitIndexMath.h>
4
+ #include <ATen/native/TensorIterator.h>
5
+ #include <ATen/core/IListRef.h>
6
+ #include <c10/util/irange.h>
7
+
8
+ namespace at::native {
9
+
10
+ [[noreturn]]
11
+ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
12
+ TORCH_CHECK_INDEX(false, "The shape of the mask ", mask.sizes(), " at index ", maskIdx,
13
+ " does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx);
14
+ }
15
+
16
+
17
+ static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef indices) {
18
+ // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
19
+ std::vector<Tensor> result;
20
+ for (const auto& index_opt : indices) {
21
+ if (!index_opt.has_value()) {
22
+ result.emplace_back();
23
+ } else {
24
+ const auto& index = *index_opt;
25
+ if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
26
+ if (index.scalar_type() == kByte) {
27
+ TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
28
+ " please use a dtype torch.bool instead.");
29
+ }
30
+ // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
31
+ // corresponding dimensions in self
32
+ for (const auto j : c10::irange(index.dim())) {
33
+ int64_t srcIdx = static_cast<int64_t>(result.size() + j);
34
+ if (index.size(j) != self.size(srcIdx)) {
35
+ invalid_mask(self, srcIdx, index, j);
36
+ }
37
+ }
38
+ // Replace with nonzeros
39
+ auto nonzero = index.nonzero();
40
+ for (const auto j : c10::irange(index.dim())) {
41
+ result.emplace_back(nonzero.select(1, j));
42
+ }
43
+ } else {
44
+ result.emplace_back(index);
45
+ }
46
+ }
47
+ }
48
+ return result;
49
+ }
50
+
51
+ static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
52
+ for (const auto& tensor : indices) {
53
+ if (tensor.has_value() && tensor->defined()) {
54
+ auto scalarType = tensor->scalar_type();
55
+ if (allow_int) {
56
+ if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
57
+ TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
58
+ }
59
+ } else {
60
+ if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
61
+ TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
62
+ }
63
+ }
64
+ }
65
+ }
66
+ }
67
+
68
+ inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
69
+ torch::List<std::optional<Tensor>> result;
70
+ result.reserve(list.size());
71
+ for (const Tensor& a : list) {
72
+ result.push_back(a);
73
+ }
74
+ return result;
75
+ }
76
+
77
+ inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<IValue> list) {
78
+ torch::List<std::optional<Tensor>> result;
79
+ result.reserve(list.size());
80
+ for (const IValue& a : list) {
81
+ result.push_back(a.isTensor() ? std::optional<Tensor>(a.toTensor()) : std::optional<Tensor>());
82
+ }
83
+ return result;
84
+ }
85
+
86
+ static C10_UNUSED bool hasContiguousSubspace(TensorList tl) {
87
+ // true if all the non-null tensors are adjacent
88
+ auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
89
+ auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
90
+ auto start = std::find_if(tl.begin(), tl.end(), isDefined);
91
+ auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
92
+ auto it = std::find_if(start, stop.base(), isNull);
93
+ return it == stop.base();
94
+ }
95
+
96
+
97
+ // Transposes the tensor and indices together so that all the non-null indices
98
+ // index the first k dimensions of the tensor. Returns the transposed tensor
99
+ // and the reordered indices. For example:
100
+ // transposeToFront(tensor, {nullptr, a, nullptr, b})
101
+ // returns
102
+ // tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
103
+ static C10_UNUSED std::tuple<Tensor, std::vector<Tensor>>
104
+ transposeToFront(const Tensor& self, TensorList indices) {
105
+ std::vector<int64_t> dims;
106
+ std::vector<Tensor> transposedIndices;
107
+ dims.reserve(self.dim());
108
+ for (const auto i : c10::irange(self.dim())) {
109
+ if (indices[i].defined()) {
110
+ dims.push_back(i);
111
+ transposedIndices.emplace_back(indices[i]);
112
+ }
113
+ }
114
+ for (const auto i : c10::irange(self.dim())) {
115
+ if (!indices[i].defined()) {
116
+ dims.push_back(i);
117
+ transposedIndices.emplace_back();
118
+ }
119
+ }
120
+ return std::make_tuple(self.permute(dims), std::move(transposedIndices));
121
+ }
122
+
123
+ inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
124
+ transposeToFrontAndInvPerm(const Tensor& self, TensorList indices) {
125
+ std::vector<int64_t> dims;
126
+ std::vector<int64_t> invPerm;
127
+ std::vector<Tensor> transposedIndices;
128
+ dims.reserve(self.dim());
129
+ invPerm.resize(self.dim());
130
+ for (const auto i : c10::irange(self.dim())) {
131
+ if (indices[i].defined()) {
132
+ dims.push_back(i);
133
+ transposedIndices.emplace_back(indices[i]);
134
+ }
135
+ }
136
+ for (const auto i : c10::irange(self.dim())) {
137
+ if (!indices[i].defined()) {
138
+ dims.push_back(i);
139
+ transposedIndices.emplace_back();
140
+ }
141
+ }
142
+ for (const auto i : c10::irange(self.dim())) {
143
+ invPerm[dims[i]] = i;
144
+ }
145
+ return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
146
+ }
147
+
148
+ struct AdvancedIndex {
149
+ AdvancedIndex(const Tensor& src, TensorList indices);
150
+
151
+ Tensor src;
152
+ std::vector<Tensor> indices;
153
+ DimVector indexed_sizes;
154
+ DimVector indexed_strides;
155
+ int64_t dims_before;
156
+ int64_t dims_after;
157
+ };
158
+
159
+
160
+ } //namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Lerp.h ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <ATen/OpMathType.h>
5
+ #include <ATen/TensorIterator.h>
6
+ #include <c10/core/Scalar.h>
7
+
8
+ namespace at::native {
9
+
10
+ template <typename scalar_t>
11
+ C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) {
12
+ return std::abs(weight) < scalar_t(0.5);
13
+ }
14
+ template <typename scalar_t>
15
+ C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(c10::complex<scalar_t> weight) {
16
+ // Avoid the sqrt in abs(weight)
17
+ return (weight.real() * weight.real() + weight.imag() * weight.imag()) < scalar_t(0.25);
18
+ }
19
+
20
+ template <typename scalar_t, typename weight_t>
21
+ C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(scalar_t self_, scalar_t end_, weight_t weight_) {
22
+ using opmath_t = at::opmath_type<scalar_t>;
23
+ using opmath_weight_t = at::opmath_type<weight_t>;
24
+
25
+ opmath_t self = self_;
26
+ opmath_t end = end_;
27
+ opmath_weight_t weight = weight_;
28
+
29
+ // Conditional for better numeric. This has been discussed in
30
+ // https://github.com/pytorch/pytorch/pull/18871
31
+ return is_lerp_weight_small(weight)
32
+ ? self + weight * (end - self)
33
+ : end - (end - self) * (opmath_t(1) - weight);
34
+ }
35
+
36
+ using lerp_fn_scalar = void (*)(
37
+ at::TensorIteratorBase& iter,
38
+ const Scalar& weight);
39
+
40
+ using lerp_fn_tensor = void (*)(
41
+ at::TensorIteratorBase& iter);
42
+
43
+ DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight);
44
+ DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight);
45
+
46
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebra.h ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
+ namespace c10 {
6
+ class Scalar;
7
+ }
8
+
9
+ namespace at {
10
+ struct TensorIterator;
11
+ }
12
+
13
+ namespace at::native {
14
+
15
+ using addr_fn = void (*)(TensorIterator &, const Scalar& beta, const Scalar& alpha);
16
+ DECLARE_DISPATCH(addr_fn, addr_stub);
17
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h ADDED
@@ -0,0 +1,623 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/ScalarType.h>
4
+ #include <c10/util/irange.h>
5
+ #include <c10/util/Exception.h>
6
+ #include <c10/util/strides.h>
7
+ #include <ATen/core/Tensor.h>
8
+ #include <ATen/ExpandUtils.h>
9
+ #include <ATen/TensorUtils.h>
10
+ #include <ATen/native/TensorIterator.h>
11
+ #include <ATen/native/TransposeType.h>
12
+ #include <limits>
13
+ #include <type_traits>
14
+ #include <sstream>
15
+ #include <cstring>
16
+ #include <cctype>
17
+
18
+ #ifndef AT_PER_OPERATOR_HEADERS
19
+ #include <ATen/Functions.h>
20
+ #else
21
+ #include <ATen/ops/arange.h>
22
+ #include <ATen/ops/empty.h>
23
+ #include <ATen/ops/empty_like.h>
24
+ #include <ATen/ops/empty_strided.h>
25
+ #include <ATen/ops/zeros.h>
26
+ #endif
27
+
28
+ namespace at::native {
29
+
30
+ inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
31
+ if (tensor.is_conj()) {
32
+ return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
33
+ } else {
34
+ return c10::MaybeOwned<Tensor>::borrowed(tensor);
35
+ }
36
+ }
37
+
38
+ inline DimVector batched_matrix_contiguous_strides(
39
+ const IntArrayRef sizes,
40
+ const bool f_contig = false) {
41
+ // f_contig chooses between the strides of a batch of Fortran (F-contiguous)
42
+ // and C-contiguous matrices
43
+ auto strides = c10::contiguous_strides(sizes);
44
+ auto dim = strides.size();
45
+
46
+ if (f_contig && dim >= 2) {
47
+ // Fix the strides of the last two dimensions, so that we return
48
+ // C-contiguous batches of F-contiguous matrices.
49
+ strides[dim - 1] = std::max(sizes[dim - 2], static_cast<int64_t>(1));
50
+ strides[dim - 2] = 1;
51
+ }
52
+ return strides;
53
+ }
54
+
55
+ /*
56
+ * Clones a Tensor so that the following conditions hold:
57
+ * If we think of a Tensor of having size (B, M, N), where B is any number
58
+ * of batch dimensions, then:
59
+ * - Each (M, N) matrix is in column major form
60
+ * - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
61
+ * Then when laid out in memory, the M by N matrix starting at
62
+ * P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N'
63
+ * matrix starting at Q.data_ptr()[B * M' * N'].
64
+ */
65
+ inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
66
+ // If src is already in batched column major format, then
67
+ // this will be efficient (no reordering of the data will occur)
68
+ // because the first transpose will make the tensor contiguous,
69
+ // and cloning a contiguous tensor is fast.
70
+ auto result = src.mT().clone(at::MemoryFormat::Contiguous);
71
+ result.transpose_(-2, -1);
72
+ return result;
73
+ }
74
+
75
+ /*
76
+ * contig chooses between C-contig (true) and F-contig (false)
77
+ */
78
+ inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
79
+ return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
80
+ : c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
81
+ : cloneBatchedColumnMajor(clone));
82
+ }
83
+
84
+ /*
85
+ * This method is designed to be a faster alternative to
86
+ * `cloneBatchedColumnMajor` with some additional features,
87
+ * namely:
88
+ * 1. It uses `copy` instead of `clone` which could be much faster.
89
+ * 2. `nrows` parameter used to create inputs with the number of rows larger
90
+ * than the original input, which is required for some LAPACK/MAGMA methods.
91
+ * 3. `desired_batch_size` is used to create copies with the batch size
92
+ * which is either the original batch size of the input, or its larger
93
+ * broadcasted shape.
94
+ */
95
+ inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
96
+ at::OptionalIntArrayRef desired_batch_sizes = std::nullopt) {
97
+ nrows = (nrows == -1) ? src.size(-2) : nrows;
98
+ auto copy_sizes = desired_batch_sizes.has_value()
99
+ ? desired_batch_sizes.value().vec()
100
+ : IntArrayRef(src.sizes().data(), src.dim() - 2).vec();
101
+ copy_sizes.insert(copy_sizes.end(), {nrows, src.size(-1)});
102
+ const auto copy_strides = batched_matrix_contiguous_strides(copy_sizes, /*f-contig*/true);
103
+ auto copy = at::empty_strided(copy_sizes, copy_strides, src.options());
104
+ copy.narrow(-2, 0, src.size(-2)).copy_(src);
105
+ return copy;
106
+ }
107
+
108
+ /*
109
+ * Given batches of matrices with arbitrary batch dim,
110
+ * computes the number of batches.
111
+ */
112
+ inline int64_t batchCount(const Tensor& batched_matrices) {
113
+ int64_t result = 1;
114
+ for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
115
+ result *= batched_matrices.size(i);
116
+ }
117
+ return result;
118
+ }
119
+
120
+ // Computes the number of elements of a matrix in a batched matrix tensor
121
+ inline int64_t matrixStride(const Tensor& batched_matrices) {
122
+ return batched_matrices.size(-1) * batched_matrices.size(-2);
123
+ }
124
+
125
+ // Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
126
+ inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
127
+ TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");
128
+ }
129
+ inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
130
+ checkIsMatrix(self, f_name, arg_name);
131
+ TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2),
132
+ f_name,
133
+ ": ", arg_name, " must be batches of square matrices, "
134
+ "but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices");
135
+ }
136
+
137
+ inline void checkInputsSolver(const Tensor& A,
138
+ const Tensor& B,
139
+ const bool left,
140
+ const char* const f_name) {
141
+ squareCheckInputs(A, f_name, "A");
142
+ checkIsMatrix(B, f_name, "B");
143
+ TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),
144
+ f_name, ": Incompatible shapes of A and B for the equation ",
145
+ left ? "AX = B" : "XA = B",
146
+ " (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
147
+ }
148
+
149
+ inline bool is_row_or_column_contiguous(const Tensor& t) {
150
+ // This could be made more general, similar to how it's checked in matmul, which would allow to
151
+ // ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
152
+ // We choose to be conservative for simplicity
153
+ return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
154
+ }
155
+
156
+ inline TransposeType to_transpose_type(const bool contig, const bool conj) {
157
+ if (conj) {
158
+ if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
159
+ else { return TransposeType::ConjTranspose; }
160
+ } else {
161
+ if (contig) { return TransposeType::NoTranspose; }
162
+ else { return TransposeType::Transpose; }
163
+ }
164
+ }
165
+
166
+
167
+ // This function is designed to be used with linear algebra methods that minimize
168
+ // L(ax - b) = 0, where L is generally the identity map (`solve`, for example)
169
+ // or the L2 norm (`lstsq`).
170
+ // It is expected that `a` and `b` are contiguous tensors of column-major matrices
171
+ // (so that a.view({-1, a.size(-2), a.size(-1)}) succeeds, same for `b`),
172
+ // with the following additional properties:
173
+ //
174
+ // 1. a.dim() == b.dim()
175
+ // 2. a.shape[:-2] broadcasts over b.shape[:-2]
176
+ // 3. a.size(i) <= b.size(i) for i=0,..., a.dim() - 3 (only for batch dimensions)
177
+ //
178
+ // MAGMA/LAPACK modify tensor `a` in-place, and the main goal of this method
179
+ // is to be memory efficient, which means that if there exists an index i such that
180
+ // a.shape[i] < b.shape[i], 0 <= i <= a.dim() - 3,
181
+ // then instead of materializing copies of `a` in the broadcasted shape, we keep
182
+ // a buffer copy of `a` along with flags that check whether specific batch dimension
183
+ // indices for `a` were already accessed. If they were, we copy the data from the buffer
184
+ // into `a`. The number of copies does not exceed
185
+ // prod(max(a.shape[:-2], b.shape[:-2]) - a.shape[:-2] + 1)
186
+ // and this value is attained by tensors with non-empty batch dimensions.
187
+ //
188
+ // func_t `f` is a callable that is being supplied with
189
+ // scalar_t* a_working_ptr, scalar_t* b_working_ptr, int64_t a_linear_batch_idx.
190
+ // a_working_ptr and b_working_ptr can directly be passed to LAPACK/MAGMA routines,
191
+ // and a_linear_batch_idx is an index in the 3d representation which corresponds to
192
+ // the memory a_working_ptr points to, in other words:
193
+ // a_working_ptr == a.view({-1, a.size(-2), a.size(-1)}.select(0, a_linear_batch_idx).data_ptr<scalar_t>();
194
+ // a_linear_batch_idx is useful to store metadata related to `a`, such as, for example,
195
+ // its rank or singular values (see linalg_lstsq).
196
+ template<typename scalar_t, typename func_t>
197
+ void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
198
+ IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
199
+ IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);
200
+
201
+ auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
202
+ auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);
203
+
204
+ TensorIterator iter = TensorIteratorConfig()
205
+ .set_check_mem_overlap(false)
206
+ .check_all_same_dtype(false)
207
+ .resize_outputs(false)
208
+ .add_output(b_linear_batch_idx)
209
+ .add_input(a_linear_batch_idx)
210
+ .build();
211
+
212
+ auto m = a.size(-2);
213
+ auto n = a.size(-1);
214
+ auto a_3d = a.view({batchCount(a), m, n});
215
+ auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});
216
+
217
+ auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
218
+ Tensor a_buffer, a_was_accessed, a_buffer_3d;
219
+ std::function<void(int64_t)> check_if_copy_needed_for_a
220
+ = [](int64_t /*a_curr_linear_batch_idx*/){};
221
+ if (a_broadcasts_over_b) {
222
+ a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
223
+ .copy_(a);
224
+ a_was_accessed = at::zeros(batchCount(a), at::kBool);
225
+ a_buffer_3d = a_buffer.view({batchCount(a), m, n});
226
+ check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
227
+ auto* a_was_accessed_flag = a_was_accessed
228
+ .select(0, a_curr_linear_batch_idx)
229
+ .data_ptr<bool>();
230
+ if (!(*a_was_accessed_flag)) {
231
+ *a_was_accessed_flag = true;
232
+ }
233
+ else {
234
+ a_3d.select(0, a_curr_linear_batch_idx)
235
+ .copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));
236
+ }
237
+ };
238
+ }
239
+
240
+ auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
241
+ auto* b_batch_idx_ptr = data[0];
242
+ auto* a_batch_idx_ptr = data[1];
243
+
244
+ for (const auto elem C10_UNUSED : c10::irange(nelems)) {
245
+ auto b_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
246
+ auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);
247
+
248
+ check_if_copy_needed_for_a(a_curr_linear_batch_idx);
249
+
250
+ auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
251
+ .data_ptr<scalar_t>();
252
+ auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
253
+ .data_ptr<scalar_t>();
254
+ f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);
255
+
256
+ b_batch_idx_ptr += strides[0];
257
+ a_batch_idx_ptr += strides[1];
258
+ }
259
+ };
260
+ iter.serial_for_each(loop, {0, batchCount(b)});
261
+ }
262
+
263
+ // Returns the epsilon value for floating types except half
264
+ inline double _get_epsilon(const ScalarType& sc_type) {
265
+ switch (sc_type) {
266
+ case at::ScalarType::Float:
267
+ return static_cast<double>(std::numeric_limits<float>::epsilon());
268
+ case at::ScalarType::Double:
269
+ return std::numeric_limits<double>::epsilon();
270
+ default:
271
+ AT_ERROR("This function doesn't handle types other than float and double");
272
+ }
273
+ }
274
+
275
+ // Validates input shapes and devices
276
+ // for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
277
+ inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
278
+ TORCH_CHECK(self.device() == A.device(),
279
+ "Expected b and A to be on the same device, but found b on ",
280
+ self.device(), " and A on ", A.device(), " instead.");
281
+
282
+ TORCH_CHECK(self.scalar_type() == A.scalar_type(),
283
+ "Expected b and A to have the same dtype, but found b of type ",
284
+ self.scalar_type(), " and A of type ", A.scalar_type(), " instead.");
285
+
286
+ TORCH_CHECK(A.size(-1) == A.size(-2),
287
+ "A must be batches of square matrices, "
288
+ "but they are ", A.size(-2), " by ", A.size(-1), " matrices");
289
+
290
+ TORCH_CHECK(A.size(-1) == self.size(-2),
291
+ "Incompatible matrix sizes for ", name, ": each A "
292
+ "matrix is ", A.size(-1), " by ", A.size(-1),
293
+ " but each b matrix is ", self.size(-2), " by ", self.size(-1));
294
+ }
295
+
296
+ inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
297
+ auto dtype = t.scalar_type();
298
+ TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)),
299
+ f_name, ": Expected a floating point or complex tensor as input. Got ", dtype);
300
+ if (!allow_low_precision_dtypes) {
301
+ TORCH_CHECK(dtype == kFloat || dtype == kDouble || dtype == kComplexFloat || dtype == kComplexDouble,
302
+ f_name, ": Low precision dtypes not supported. Got ", dtype);
303
+ }
304
+ }
305
+
306
+
307
+ // Checks if all the Tensors in a TensorList are of the same dimensions
308
+ inline void checkAllSameDim(TensorList tensors, int64_t dim) {
309
+ for (auto &t : tensors) {
310
+ TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
311
+ }
312
+ }
313
+
314
+ inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
315
+ // broadcast the batch dimensions of arg1 and arg2.
316
+ IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
317
+ IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
318
+ std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);
319
+
320
+ std::vector<int64_t> arg1_expand_size({expand_batch_portion});
321
+ arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });
322
+
323
+ std::vector<int64_t> arg2_expand_size({expand_batch_portion});
324
+ arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
325
+ return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
326
+ }
327
+
328
+ inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
329
+ // If there's no name we assume we don't want to check the errors
330
+ if (name != nullptr) {
331
+ linearSolveCheckInputs(arg1, arg2, name);
332
+ }
333
+
334
+ auto [arg1_expand_size, arg2_expand_size] = at::native::_linalg_broadcast_batch_dims(arg1, arg2);
335
+
336
+ auto arg1_broadcasted = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
337
+ auto arg2_broadcasted = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
338
+ return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
339
+ }
340
+
341
+ inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
342
+ IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
343
+ IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
344
+ auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
345
+ return broadcasted_batch_sizes;
346
+ }
347
+
348
+ // Return a permutation with the given axes moved to the end.
349
+ inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
350
+ const std::vector<int64_t> a = axes.vec();
351
+ const int64_t ndim = self.ndimension();
352
+ std::vector<int64_t> perm;
353
+
354
+ for (const auto i : c10::irange(ndim)) {
355
+ auto it = std::find(a.begin(), a.end(), i);
356
+ if (it == a.end()) {
357
+ perm.push_back(i);
358
+ }
359
+ }
360
+ for (auto i : a) {
361
+ perm.push_back(i);
362
+ }
363
+
364
+ TORCH_CHECK((int64_t)perm.size() == ndim,
365
+ "duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);
366
+
367
+ return self.permute(perm);
368
+ }
369
+
370
+ // parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
371
+ inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
372
+ bool compute_q;
373
+ bool reduced;
374
+ if (mode == "reduced") {
375
+ compute_q = true;
376
+ reduced = true;
377
+ } else if (mode == "complete") {
378
+ compute_q = true;
379
+ reduced = false;
380
+ } else if (mode == "r") {
381
+ compute_q = false;
382
+ reduced = true; // this is actually irrelevant in this mode
383
+ } else {
384
+ TORCH_CHECK(false, "qr received unrecognized mode '", mode,
385
+ "' but expected one of 'reduced' (default), 'r', or 'complete'");
386
+ }
387
+ return std::make_tuple(compute_q, reduced);
388
+ }
389
+
390
+ // Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
391
+ inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
392
+ const Tensor& input,
393
+ bool reduced) {
394
+ int64_t m = input.size(-2), n = input.size(-1);
395
+ int64_t n_columns_q;
396
+
397
+ // We need to compute the required size of Q based on the `reduced` option
398
+ DimVector q_sizes(input.sizes());
399
+ if (!reduced && m > n) {
400
+ q_sizes[input.dim() - 1] = m;
401
+ n_columns_q = m;
402
+ } else {
403
+ q_sizes[input.dim() - 1] = n;
404
+ n_columns_q = std::min(m, n);
405
+ }
406
+ auto q_strides = batched_matrix_contiguous_strides(q_sizes, /*f-contig*/true);
407
+ return std::make_tuple(q_sizes, q_strides, n_columns_q);
408
+ }
409
+
410
+ inline bool svd_uses_cusolver(const Tensor& A) {
411
+ // if cusolver is available, it is used unconditionally
412
+ return A.is_cuda()
413
+ && at::globalContext().hasCuSOLVER()
414
+ && at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma;
415
+ }
416
+
417
+
418
+ // Function used instead of .to so that the original strides are retained
419
+ // .to doesn't retain strides and make the output tensor contiguous
420
+ inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
421
+ auto strided_to = at::empty_strided(original_tensor.sizes(),
422
+ original_tensor.strides(),
423
+ options);
424
+ strided_to.copy_(original_tensor);
425
+ return strided_to;
426
+ }
427
+
428
+ // Creates a dimension permutation array that can be given to `at::permute()`, which will shift
429
+ // the two specified dimensions to the end of a tensor, without changing the order of
430
+ // the other dimensions. `dim1` will be placed at the very end, and `dim0` will be
431
+ // placed just to the left of it.
432
+ //
433
+ // For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
434
+ // calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
435
+ // be `vec(0, 2, 1, 3)`.
436
+ inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
437
+ TORCH_CHECK(
438
+ (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
439
+ "duplicate or invalid dimensions");
440
+ std::vector<int64_t> permutation(ndim);
441
+ int64_t cur_permuted_dim = 0;
442
+ for (const auto dim_ind : c10::irange(ndim)) {
443
+ if ((dim_ind != dim0) && (dim_ind != dim1)) {
444
+ permutation[cur_permuted_dim++] = dim_ind;
445
+ }
446
+ }
447
+ permutation[cur_permuted_dim++] = dim0;
448
+ permutation[cur_permuted_dim] = dim1;
449
+ return permutation;
450
+ }
451
+
452
+ // Creates a dimension permutation array that can be given to `at::permute()`, which
453
+ // will reverse a given permutation.
454
+ // The reverse permutation array is created by swapping the indices and their
455
+ // associated values from the given permutation array.
456
+ inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
457
+ int64_t ndim = permutation.size();
458
+ std::vector<int64_t> reverse_permutation(ndim);
459
+ for (const auto dim_ind : c10::irange(ndim)) {
460
+ reverse_permutation[permutation[dim_ind]] = dim_ind;
461
+ }
462
+ return reverse_permutation;
463
+ }
464
+
465
+ // Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd
466
+ // See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186
467
+ inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
468
+ auto mn = std::min(m, n);
469
+ auto mx = std::max(m, n);
470
+ if (jobz == 'N') {
471
+ #ifdef __APPLE__
472
+ // According to `vecLib.framework/Headers/clapack.h` Accelerate.framework is based on LAPACK 3.2.1
473
+ return 7 * mn;
474
+ #else
475
+ // These setting is valid for on LAPACK 3.6+
476
+ return 5 * mn;
477
+ #endif
478
+ }
479
+ if (mx > 10 * mn) {
480
+ return 5 * mn * mn + 5 * mn;
481
+ }
482
+ return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn);
483
+ }
484
+
485
+ // This function checks whether the uplo argument input is valid
486
+ // Allowed strings are "u", "U", "l", "L"
487
+ inline void checkUplo(const c10::string_view uplo) {
488
+ // To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
489
+ char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
490
+ TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
491
+ "Expected UPLO argument to be 'L' or 'U', but got ", uplo);
492
+ }
493
+
494
+ inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
495
+ TORCH_CHECK(
496
+ result.device() == input.device(),
497
+ fn_name,
498
+ ": Expected ", result_name, " and input tensors to be on the same device, but got ",
499
+ result_name, " on ", result.device(), " and input on ", input.device());
500
+ }
501
+
502
+ // Check the dtype of result and input tensors (for _out variants).
503
+ // Most linear algebra functions have the same dtype for input and output
504
+ // (either floating or complex type input), so we can check whether input's dtype can be casted to result's dtype.
505
+ // According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
506
+ // c10::canCast is used for checking the "safe copy" dtype requirements.
507
+ inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
508
+ bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type());
509
+ TORCH_CHECK(
510
+ can_cast,
511
+ fn_name,
512
+ ": Expected ", result_name, " to be safely castable from ", input.scalar_type(), " dtype, but got ",
513
+ result_name, " with dtype ", result.scalar_type());
514
+ }
515
+
516
+ // Alternatively, we can check whether the specific expected output type (result_type) can be safely casted to out tensor dtype (out_type)
517
+ inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
518
+ bool can_cast = c10::canCast(result_type, out_type);
519
+ TORCH_CHECK(
520
+ can_cast,
521
+ fn_name,
522
+ ": Expected ", out_name, " to be safely castable from ", result_type, " dtype, but got ",
523
+ out_name, " with dtype ", out_type);
524
+ }
525
+
526
+ inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
527
+ TORCH_CHECK(!at::isComplexType(tol.scalar_type()),
528
+ f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type());
529
+ }
530
+
531
+ /*
532
+ Two types of 'other' tensors are supported when solving
533
+ a system of linear equations matmul(input, x) = other:
534
+ * 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
535
+ * 2-dimensional (2D) tensor or batch of 2D tensors (matrix case).
536
+ The original torch.solve supported only the matrix case, while NumPy works for both cases.
537
+ For the batched input we need to be able to distinguish them.
538
+ Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
539
+ This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
540
+ */
541
+ inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
542
+ auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1]
543
+ bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape));
544
+ return vector_case;
545
+ }
546
+
547
+ /*
548
+ Computes linear indices for a tensor with original_shape to access its elements like it was a materialized broadcast tensor.
549
+ */
550
+ inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
551
+ TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
552
+ return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
553
+ }
554
+
555
+ class BroadcastLinearIndices {
556
+ private:
557
+ Tensor linear_indices_;
558
+ bool is_broadcasting_;
559
+
560
+ public:
561
+ BroadcastLinearIndices(
562
+ int64_t numel,
563
+ IntArrayRef original_shape,
564
+ IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
565
+ // The assumption is that the broadcast_shape is a materialized broadcast
566
+ // shape of the original_shape. We need to compute the linear indices
567
+ // compatible with the original_shape to access the elements in the original
568
+ // tensor corresponding to the broadcast tensor.
569
+ if (is_broadcasting_) {
570
+ linear_indices_ =
571
+ get_linear_indices(numel, original_shape, broadcast_shape);
572
+ }
573
+ }
574
+ int64_t operator()(int64_t broadcast_linear_index) {
575
+ return is_broadcasting_
576
+ ? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
577
+ : broadcast_linear_index;
578
+ }
579
+ };
580
+
581
+ inline bool is_blas_compatible_column_major_order(const Tensor& input) {
582
+ IntArrayRef input_strides = input.strides();
583
+ IntArrayRef input_sizes = input.sizes();
584
+ auto ndim = input.dim();
585
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
586
+ if (ndim > 3) {
587
+ return input.transpose(-2, -1).is_contiguous();
588
+ }
589
+ auto leading_dimension = input_strides[ndim - 1];
590
+ auto rows = input_sizes[ndim - 2];
591
+ bool batch_stride_compatible = true;
592
+ if (ndim == 3) {
593
+ auto cols = input_sizes[ndim - 1];
594
+ batch_stride_compatible =
595
+ input_strides[ndim - 3] >= leading_dimension * cols;
596
+ }
597
+ return (input_strides[ndim - 2] == 1) &&
598
+ (leading_dimension >= std::max<int64_t>(1, rows)) &&
599
+ batch_stride_compatible;
600
+ }
601
+
602
+ inline bool is_blas_compatible_row_major_order(const Tensor& input) {
603
+ IntArrayRef input_strides = input.strides();
604
+ IntArrayRef input_sizes = input.sizes();
605
+ auto ndim = input.dim();
606
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
607
+ if (ndim > 3) {
608
+ return input.is_contiguous();
609
+ }
610
+ auto leading_dimension = input_strides[ndim - 2];
611
+ auto cols = input_sizes[ndim - 1];
612
+ bool batch_stride_compatible = true;
613
+ if (ndim == 3) {
614
+ auto rows = input_sizes[ndim - 2];
615
+ batch_stride_compatible =
616
+ input_strides[ndim - 3] >= leading_dimension * rows;
617
+ }
618
+ return (input_strides[ndim - 1] == 1) &&
619
+ (leading_dimension >= std::max<int64_t>(1, cols)) &&
620
+ batch_stride_compatible;
621
+ }
622
+
623
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitsFallback.h ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+ #include <ATen/core/dispatch/Dispatcher.h>
3
+ #include <ATen/core/op_registration/op_registration.h>
4
+ #include <ATen/native/UnaryOps.h>
5
+ #include <ATen/native/Resize.h>
6
+ #include <c10/util/irange.h>
7
+ #include <torch/library.h>
8
+
9
+ #ifndef AT_PER_OPERATOR_HEADERS
10
+ #include <ATen/Functions.h>
11
+ #else
12
+ #include <ATen/ops/clone.h>
13
+
14
+ #include <utility>
15
+ #endif
16
+
17
+ namespace at::native {
18
+ // This fallback should only be used for operations that are self inverse and have a corresponding tensor
19
+ // bit (internally implemented using DispatchKey) to maintain the state on tensor using tensor bit.
20
+ // Currently there are two tensor bits that trigger this fallback: conjugate bit and negative bit.
21
+ // Conjugate bit is set on a tensor when `.conj()` is called and neg bit is set on a tensor when `.conj().imag` is called.
22
+
23
+ // NOTE: To use this fallback, `clone` and `copy_` should fully understand and be able to correctly handle the semantic of your math bit.
24
+ struct MathOpFallback {
25
+ MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(std::move(op_name_)) {}
26
+ virtual bool is_bit_set(const Tensor&) = 0;
27
+ void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
28
+ /*
29
+ Situations to handle:
30
+ 1. Out-of-place operation. Easy: materialize all inputs and
31
+ call it a day.
32
+ 2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
33
+ Materialize other inputs as in (1).
34
+ 3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
35
+ Materialize other inputs as in (1).
36
+
37
+ It is important to be able to tell if we READ from an argument and if we
38
+ WRITE to an argument. Conservative approach is to assume that we always
39
+ READ from an argument, but in out= operations you can skip
40
+ conjugating inputs on entry that never get used. In the current schema we
41
+ can't easily tell if the operation is in in-place or out= operation.
42
+
43
+ Note:
44
+ 1. Mutable tensorlists containing tensors whose math bit set to true are disallowed.
45
+ 2. Mutable tensors with math bit set to true are unconditionally cloned to ensure
46
+ correct behavior in the case when the mutable tensor shares memory with non mutable arguments.
47
+
48
+ If we were to in-place resolve the math bit for mutable inputs, then the non-mutable inputs sharing partial or full memory
49
+ with these mutable inputs would read into wrong values in the following cases:
50
+ 1. Non mutable inputs have their math bit set to false.
51
+ 2. Math bit for mutable input(s) is resolved before the non mutable inputs (with bit set to true and sharing memory
52
+ with one or more mutable arg(s)) are cloned.
53
+ At the end, the final value of the mutable arguments from the stack are copied into the original input mutable tensor inputs.
54
+ */
55
+ const auto& arguments = op.schema().arguments();
56
+ const auto num_arguments = arguments.size();
57
+ const auto stack_start = stack->size() - num_arguments;
58
+
59
+ std::optional<bool> is_write;
60
+ for (const auto i : c10::irange(num_arguments)) {
61
+ // Three possible states:
62
+ // 1. alias_info has no value --> out-of-place operation
63
+ // 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
64
+ // 3. alias_info does have a value, alias_info->is_write=False --> view operation
65
+ const AliasInfo* alias_info = arguments[i].alias_info();
66
+ if (alias_info != nullptr) {
67
+ if (is_write.has_value()) {
68
+ TORCH_CHECK(*is_write == alias_info->isWrite(),
69
+ "Unsupported operator for ", op_name, " fallback: ", op.schema().name(),
70
+ op_name, " fallback doesn't work for operators with a mix "
71
+ "mutable and non-mutable inputs that alias with outputs, "
72
+ "this must be implemented manually. "
73
+ "If you got this error on a core op, please report a bug to PyTorch.");
74
+ } else {
75
+ is_write = alias_info->isWrite();
76
+ }
77
+ }
78
+ }
79
+
80
+ if (is_write.has_value() && !*is_write) {
81
+ // We assume that view operators automatically handle the math bit
82
+ // correctly by propagating the dispatch key in key_set.
83
+ // This is not necessarily always right, so you should test these cases.
84
+ op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
85
+ return;
86
+ }
87
+
88
+ // Mutable inputs with math bit set to True and their clones
89
+ std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones;
90
+ for (const auto i : c10::irange(num_arguments)) {
91
+ auto& ivalue = (*stack)[stack_start + i];
92
+ if (!(ivalue.isTensor() || ivalue.isTensorList())) {
93
+ continue;
94
+ }
95
+ const auto& argument = arguments[i];
96
+ bool mut_arg = false;
97
+ if (argument.alias_info()) {
98
+ // Was already tested by is_write loop above
99
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
100
+ mut_arg = true;
101
+ }
102
+ if (ivalue.isTensor()) {
103
+ if (!is_bit_set(ivalue.toTensor())) {
104
+ continue;
105
+ }
106
+ auto tensor = std::move(ivalue).toTensor();
107
+ auto resolved_tensor = at::clone(tensor);
108
+ if (mut_arg) {
109
+ TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensors with ",
110
+ op_name, "bit set to true.");
111
+ mutable_inputs_with_their_clones.emplace_back(std::move(tensor), resolved_tensor);
112
+ }
113
+ (*stack)[stack_start + i] = std::move(resolved_tensor);
114
+ } else if (ivalue.isTensorList()) {
115
+ auto tensors = std::move(ivalue).toTensorList();
116
+ for(const auto j : c10::irange(tensors.size())) {
117
+ const auto& tensor = tensors[j];
118
+ if (!is_bit_set(tensor)) {
119
+ continue;
120
+ }
121
+ TORCH_CHECK(!mut_arg, " fallback doesn't currently support mutable TensorLists with ",
122
+ op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ",
123
+ op.schema().name());
124
+ tensors[j] = at::clone(tensor);
125
+ }
126
+ (*stack)[stack_start + i] = std::move(tensors);
127
+ }
128
+ }
129
+
130
+ op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
131
+
132
+ TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1);
133
+
134
+ for (std::pair<Tensor, Tensor> mut_tensors: mutable_inputs_with_their_clones) {
135
+ auto& mutable_input = mut_tensors.first;
136
+ auto& cloned_mutable_input = mut_tensors.second;
137
+ auto& ivalue = (*stack)[stack_start];
138
+ auto returned_output = std::move(ivalue).toTensor();
139
+
140
+ // sanity check to ensure that the tensor in stack aliases the cloned_mutable_input
141
+ TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output));
142
+
143
+ // necessary for out= arg
144
+ at::native::resize_output(mutable_input, returned_output.sizes());
145
+
146
+ mutable_input.copy_(returned_output);
147
+ (*stack)[stack_start] = std::move(mutable_input);
148
+ }
149
+ }
150
+
151
+ virtual ~MathOpFallback() = default;
152
+
153
+ DispatchKey key;
154
+ string op_name;
155
+ };
156
+
157
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/MaxPooling.h ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/Parallel.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+ #include <ATen/native/Pool.h>
7
+
8
+ namespace at::native {
9
+
10
+ inline void check_max_pool1d(
11
+ const Tensor& self,
12
+ IntArrayRef kernel_size,
13
+ IntArrayRef stride,
14
+ IntArrayRef padding,
15
+ IntArrayRef dilation,
16
+ bool ceil_mode) {
17
+
18
+ TORCH_CHECK(
19
+ self.dim() == 2 || self.dim() == 3,
20
+ "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
21
+ TORCH_CHECK(
22
+ kernel_size.size() == 1,
23
+ "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
24
+ kernel_size.size());
25
+ TORCH_CHECK(
26
+ stride.empty() || stride.size() == 1,
27
+ "max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
28
+ stride.size());
29
+ TORCH_CHECK(
30
+ padding.size() == 1,
31
+ "max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
32
+ padding.size());
33
+ TORCH_CHECK(
34
+ dilation.size() == 1,
35
+ "max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
36
+ dilation.size());
37
+
38
+ // If stride=None then set it to kernel_size
39
+ if (stride.empty()) {
40
+ stride = kernel_size;
41
+ }
42
+
43
+ TORCH_CHECK(
44
+ kernel_size[0] > 0,
45
+ "max_pool1d() kernel_size must be greater than zero, but got ",
46
+ kernel_size[0]);
47
+ TORCH_CHECK(
48
+ stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
49
+ TORCH_CHECK(
50
+ padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
51
+ TORCH_CHECK(
52
+ padding[0] <= kernel_size[0] / 2,
53
+ "max_pool1d() padding should be at most half of kernel size, but got padding=",
54
+ padding[0],
55
+ " and kernel_size=",
56
+ kernel_size[0]);
57
+ TORCH_CHECK(
58
+ dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
59
+
60
+ const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
61
+ TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
62
+ }
63
+
64
+ // TODO(Heitor) Template by dimension
65
+ struct PoolingParams1D {
66
+ int64_t NB; // Number of batches
67
+ int64_t NC; // Number of channels
68
+ int64_t IW; // Input width
69
+ int64_t OW; // Output width
70
+ int64_t KW; // Kernel width
71
+ int64_t SJ; // Column stride
72
+ int64_t PJ; // Column padding
73
+ int64_t DJ; // Column dilation
74
+
75
+ // Return index of input element for the given kernel and output index
76
+ inline int64_t index(int64_t kj, int64_t oj) const {
77
+ return oj * SJ + kj * DJ - PJ;
78
+ }
79
+
80
+ // Return index of first output within bounds for this kernel index
81
+ inline int64_t valid_output_start(int64_t kj) const {
82
+ int64_t ij = index(kj, 0);;
83
+ return ij < 0 ? at::divup(-ij, SJ) : 0;
84
+ }
85
+
86
+ // Return index one past last output within bounds for this kernel index
87
+ inline int64_t valid_output_end(int64_t kj) const {
88
+ int64_t ij = index(kj, OW - 1);
89
+ return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
90
+ }
91
+ };
92
+
93
+ using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);
94
+
95
+ DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
96
+
97
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonSymbolicBC.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <c10/util/irange.h>
4
+ #include <ATen/core/IListRef.h>
5
+
6
+ namespace at::native {
7
+ // This file contains non-symbolic signatures for ops that we have sym-intified the signature of.
8
+ // However, in certain cases (such as static runtime), we call the native versions of the ops directly.
9
+ // In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation.
10
+ TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape);
11
+ TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
12
+ TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, std::optional<at::ScalarType> dtype=std::nullopt, std::optional<at::Layout> layout=std::nullopt, std::optional<at::Device> device=std::nullopt, std::optional<bool> pin_memory=std::nullopt, std::optional<bool> is_coalesced=std::nullopt);
13
+ TORCH_API at::Tensor nll_loss(const at::Tensor & self, const at::Tensor & target, const std::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
14
+ TORCH_API at::Tensor nll_loss2d(const at::Tensor & self, const at::Tensor & target, const std::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
15
+ // The below ops don't get a duplicated C++ implementation.
16
+ // They are backward ops, which make them very unlikely to be called directly
17
+ // by external code (at::native::trace_backward).
18
+ // They get their own declaration for BC purposes however.
19
+ TORCH_API at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const std::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
20
+ TORCH_API at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const std::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
21
+ TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim);
22
+ TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes);
23
+ TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index);
24
+ TORCH_API at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index);
25
+ TORCH_API std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim);
26
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/RangeFactories.h ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/native/DispatchStub.h>
2
+ #include <c10/core/Scalar.h>
3
+
4
+ namespace at {
5
+ struct TensorIterator;
6
+
7
+ namespace native {
8
+
9
+ DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&), arange_stub);
10
+ DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, int64_t), linspace_stub);
11
+
12
+ }} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceAllOps.h ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
+ namespace at {
6
+ class Tensor;
7
+ }
8
+
9
+ namespace at::native {
10
+
11
+ using reduce_all_fn = void (*)(Tensor & result, const Tensor & self);
12
+ using reduce_min_max_fn = void (*)(Tensor & max_result, Tensor & min_result, const Tensor & self);
13
+ DECLARE_DISPATCH(reduce_all_fn, min_all_stub);
14
+ DECLARE_DISPATCH(reduce_all_fn, max_all_stub);
15
+
16
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReductionType.h ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/Scalar.h>
4
+
5
+ namespace at::native {
6
+
7
+ enum class ReductionType {MAX, MEAN, MIN, SUM, PROD};
8
+
9
+ inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
10
+ if (reduce == "max" || reduce == "amax") {
11
+ return ReductionType::MAX;
12
+ } else if (reduce == "mean") {
13
+ return ReductionType::MEAN;
14
+ } else if (reduce == "min" || reduce == "amin") {
15
+ return ReductionType::MIN;
16
+ } else if (reduce == "sum") {
17
+ return ReductionType::SUM;
18
+ } else if (reduce == "prod") {
19
+ return ReductionType::PROD;
20
+ } else {
21
+ TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce);
22
+ }
23
+ }
24
+
25
+ // used for `scatter_reduce`, old options for BC.
26
+ inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
27
+ if (use_new_options) {
28
+ return get_reduction_enum(reduce);
29
+ } else {
30
+ if (reduce == "add") {
31
+ return ReductionType::SUM;
32
+ } else if (reduce == "multiply") {
33
+ return ReductionType::PROD;
34
+ } else {
35
+ TORCH_CHECK(false, "reduce argument must be either add or multiply.")
36
+ }
37
+ }
38
+ }
39
+
40
+ } // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Repeat.h ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/TensorOperators.h>
5
+
6
+ #ifndef AT_PER_OPERATOR_HEADERS
7
+ #include <ATen/Functions.h>
8
+ #else
9
+ #include <ATen/ops/empty.h>
10
+ #include <ATen/ops/empty_like.h>
11
+ #endif
12
+
13
+ namespace at::native {
14
+
15
+ template <
16
+ typename index_t,
17
+ void compute(const index_t*, const int64_t*, index_t*, int64_t, int64_t)>
18
+ static inline Tensor repeat_interleave_common(
19
+ const Tensor& repeats,
20
+ std::optional<int64_t> output_size) {
21
+ TORCH_CHECK(
22
+ repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
23
+ TORCH_CHECK(
24
+ repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
25
+ "repeats has to be Long or Int tensor");
26
+ if (repeats.size(0) == 0) {
27
+ return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
28
+ }
29
+ Tensor repeats_ = repeats.contiguous();
30
+ Tensor cumsum = repeats.cumsum(0);
31
+ int64_t total = 0;
32
+ if (output_size.has_value()) {
33
+ total = output_size.value();
34
+ } else {
35
+ total = cumsum[-1].item<int64_t>();
36
+ TORCH_CHECK(
37
+ (repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
38
+ }
39
+
40
+ Tensor result = at::empty({total}, repeats.options());
41
+ const index_t* repeat_ptr = repeats_.const_data_ptr<index_t>();
42
+ const int64_t* cumsum_ptr = cumsum.const_data_ptr<int64_t>();
43
+ index_t* result_ptr = result.data_ptr<index_t>();
44
+ compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
45
+ return result;
46
+ }
47
+
48
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Resize.h ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/ResizeCommon.h>
5
+ #include <ATen/EmptyTensor.h>
6
+ #include <ATen/TensorUtils.h>
7
+
8
+ #include <c10/core/CPUAllocator.h>
9
+
10
+ #include <utility>
11
+
12
+
13
+ namespace at::native {
14
+
15
+ // TODO: make all operations that resize given outputs use this function
16
+ // for consistency and maintainability.
17
+ // Some operations like `cat` might not be able to make the use of
18
+ // resize_output directly. For more details to understand how it works in `cat`,
19
+ // see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
20
+ // Resizes outputs
21
+ // Functions accepting output tensors, like with the "out" kwarg, should
22
+ // call this function to handle resizing their output tensor.
23
+ // Issues a warning if the output tensor has one or more elements and
24
+ // needs resizing
25
+ // NOTE: In the future the warning will become an error
26
+ // Returns a bool saying whether or not the resize actually happened or not
27
+ TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
28
+ // WARNING: Do NOT call this directly. If you are resizing an output and want
29
+ // to support dynamic shapes call at::resize__symint and resize_output_check_symint.
30
+ // For more details, see: https://github.com/pytorch/pytorch/pull/111530/files#r1365845272
31
+ TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);
32
+
33
+ // Utility for resize_output
34
+ // Returns a bool saying resize should happen or not and
35
+ // raises a warning if resizing for one or more elements
36
+ TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
37
+ TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
38
+
39
+ TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
40
+ TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes);
41
+ TORCH_API void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& size_bytes);
42
+
43
+ inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
44
+ // It does not make sense to try to resize a storage
45
+ // to hold 0 elements, and this can break
46
+ // if storage_offset is positive but
47
+ // new_size is 0, so just bail in that case
48
+ // (same comment is in cuda/Resize.h)
49
+ if (self->numel() == 0) {
50
+ return;
51
+ }
52
+
53
+ const Storage& storage = self->unsafe_storage();
54
+ if (!storage) {
55
+ auto new_storage = c10::make_intrusive<StorageImpl>(
56
+ StorageImpl::use_byte_size_t(),
57
+ new_size_bytes,
58
+ c10::GetCPUAllocator(),
59
+ true);
60
+ self->set_storage_keep_dtype(std::move(new_storage));
61
+ } else if (new_size_bytes > storage.nbytes()) {
62
+ resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
63
+ }
64
+ }
65
+
66
+ TORCH_API TensorImpl* resize_impl_cpu_(
67
+ TensorImpl* self,
68
+ IntArrayRef size,
69
+ at::OptionalIntArrayRef stride,
70
+ bool resize_storage = true);
71
+
72
+ template <typename T>
73
+ T maybe_convert_symint(c10::SymInt) = delete;
74
+
75
+ template <>
76
+ inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }
77
+
78
+ template <>
79
+ inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
80
+
81
+ template <typename T>
82
+ inline void checkInBoundsForStorage(
83
+ ArrayRef<T> size,
84
+ ArrayRef<T> stride,
85
+ T storage_offset,
86
+ const caffe2::TypeMeta& data_type,
87
+ const Storage& new_storage) {
88
+ T storage_size_bytes =
89
+ at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
90
+ T storage_offset_bytes = storage_offset * data_type.itemsize();
91
+ if (storage_size_bytes == 0) {
92
+ // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
93
+ return;
94
+ }
95
+ T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
96
+ TORCH_CHECK(
97
+ storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
98
+ "setStorage: sizes ",
99
+ size,
100
+ ", strides ",
101
+ stride,
102
+ ","
103
+ " storage offset ",
104
+ storage_offset,
105
+ ", and itemsize ",
106
+ data_type.itemsize(),
107
+ " requiring a storage size of ",
108
+ storage_size_bytes + storage_offset_bytes,
109
+ " are out of bounds for storage of size ",
110
+ new_storage_size_bytes);
111
+ }
112
+
113
+ template <typename T>
114
+ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
115
+ ArrayRef<T> size, ArrayRef<T> stride) {
116
+ // FIXME: stride should be optional
117
+ if (stride.data()) {
118
+ TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
119
+ ") and stride length (", stride.size(), ")");
120
+ }
121
+
122
+ #ifdef DEBUG
123
+ TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
124
+ #endif
125
+
126
+ // storage: note this can't be replaced with result.set_(storage) as the semantics of that
127
+ // function is to set the tensor size to be equal to the size of the storage.
128
+ if (!result.storage().is_alias_of(storage)) {
129
+ // Caffe2 might have tensors whose storages are null, but we
130
+ // don't allow it in PyTorch.
131
+ TORCH_INTERNAL_ASSERT(storage);
132
+ TORCH_INTERNAL_ASSERT(result.storage());
133
+
134
+ // We used to allow this, but this breaks device caching.
135
+ // Let's put an actual error message for this one.
136
+ TORCH_CHECK(result.storage().device() == storage.device(),
137
+ "Attempted to set the storage of a tensor on device \"", result.storage().device(),
138
+ "\" to a storage on different device \"", storage.device(),
139
+ "\". This is no longer allowed; the devices must match.");
140
+ result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
141
+ }
142
+
143
+ // storageOffset
144
+ TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
145
+ }
146
+
147
+ /**
148
+ * Set self's sizes, strides, and storage_offset.
149
+ * (size, stride, storage_offset) must be in bounds for self's storage.
150
+ */
151
+ template <typename T>
152
+ inline void setStrided(
153
+ const Tensor& self,
154
+ ArrayRef<T> size,
155
+ ArrayRef<T> stride,
156
+ T storage_offset) {
157
+ TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
158
+ for (const auto& val : stride) {
159
+ TORCH_CHECK(val >= 0,
160
+ "as_strided: Negative strides are not supported at the moment, "
161
+ "got strides: ", stride);
162
+ }
163
+
164
+ auto* self_ = self.unsafeGetTensorImpl();
165
+ checkInBoundsForStorage(
166
+ size, stride, storage_offset, self_->dtype(), self_->storage());
167
+
168
+ /* storage offset */
169
+ TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
170
+ self_->set_sizes_and_strides(size, stride, std::make_optional(storage_offset));
171
+ }
172
+
173
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ResizeCommon.h ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/TensorFactories.h>
5
+ #include <ATen/NamedTensorUtils.h>
6
+ #include <c10/util/irange.h>
7
+
8
+ #ifndef AT_PER_OPERATOR_HEADERS
9
+ #include <ATen/NativeFunctions.h>
10
+ #else
11
+ #include <ATen/ops/empty.h>
12
+ #endif
13
+
14
+ namespace at::native {
15
+
16
+ template <typename T>
17
+ inline T storage_size_for(ArrayRef<T> size, ArrayRef<T> stride) {
18
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(),
19
+ "storage_size_for(size, stride) requires that size and stride ",
20
+ "have the same size as a precondition.");
21
+ T storage_size = 1;
22
+ for (const auto dim : c10::irange(size.size())) {
23
+ if (size[dim] == 0) {
24
+ storage_size = 0;
25
+ break;
26
+ }
27
+ storage_size += (size[dim] - 1) * stride[dim];
28
+ }
29
+ return storage_size;
30
+ }
31
+
32
+ inline const Tensor& resize_named_tensor_(
33
+ const Tensor& self,
34
+ IntArrayRef size,
35
+ std::optional<MemoryFormat> optional_memory_format) {
36
+ TORCH_INTERNAL_ASSERT(self.has_names());
37
+ TORCH_CHECK(
38
+ self.sizes() == size,
39
+ "Cannot resize named tensor with resize_ or resize_as_ (tried to resize "
40
+ "Tensor",
41
+ self.names(),
42
+ " with size ",
43
+ self.sizes(),
44
+ " to ",
45
+ size,
46
+ "). This may be caused by passing a named tensor ",
47
+ "as an `out=` argument; please ensure that the sizes are the same. ");
48
+ TORCH_CHECK(
49
+ !optional_memory_format.has_value(),
50
+ "Unsupported memory format for named tensor resize ",
51
+ optional_memory_format.value());
52
+ return self;
53
+ }
54
+
55
+ // For deterministic output, fill new elements that were added after a storage
56
+ // resize with NaN or MAX_INT. `old_storage_nbytes` is the size of the storage
57
+ // before the resize happened.
58
+ inline const Tensor& fill_resize_deterministic_(const Tensor& tensor, int64_t old_storage_nbytes) {
59
+ const at::Storage& storage = tensor.unsafeGetTensorImpl()->unsafe_storage();
60
+ int64_t new_storage_nbytes = storage.nbytes();
61
+ int64_t old_storage_numel = old_storage_nbytes / tensor.itemsize();
62
+ int64_t new_storage_numel = new_storage_nbytes / tensor.itemsize();
63
+ if (new_storage_numel > old_storage_numel) {
64
+ at::Tensor tensor_view = at::empty({}, at::TensorOptions().dtype(tensor.scalar_type()).device(tensor.device()));
65
+ tensor_view.set_(
66
+ storage,
67
+ /*storage_offset=*/old_storage_numel,
68
+ /*size=*/{new_storage_numel - old_storage_numel},
69
+ /*stride=*/{1});
70
+ at::native::fill_empty_deterministic_(tensor_view);
71
+ }
72
+ return tensor;
73
+ }
74
+
75
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ScatterGatherChecks.h ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <vector>
4
+ #include <ATen/core/Tensor.h>
5
+ #include <ATen/native/ReduceOpsUtils.h>
6
+ #include <c10/util/irange.h>
7
+
8
+ namespace at::native {
9
+
10
+ namespace {
11
+
12
+ // checks whether index.dtype == int64
13
+ // and self.dtype == src.dtype if src is a Tensor
14
+ inline void scatter_gather_dtype_check(
15
+ const std::string& method_name,
16
+ const Tensor& self,
17
+ const Tensor& index,
18
+ const std::optional<Tensor>& src_opt = std::nullopt
19
+ ) {
20
+ if (index.numel() != 0) {
21
+ TORCH_CHECK(
22
+ index.scalar_type() == at::ScalarType::Long,
23
+ method_name, "(): Expected dtype int64 for index"
24
+ );
25
+ }
26
+
27
+ if (src_opt.has_value()) {
28
+ const auto& src = src_opt.value();
29
+ TORCH_CHECK(
30
+ self.scalar_type() == src.scalar_type(),
31
+ method_name, "(): Expected self.dtype to be equal to src.dtype"
32
+ );
33
+ }
34
+ }
35
+
36
+ // Used for `gather`-like methods
37
+ // Note: self means the input tensor here
38
+ // Test:
39
+ // 1. index.size(d) <= self.size(d) for all d != dim
40
+ // 2. index.dim() == self.dim()
41
+ inline void gather_shape_check(const Tensor& self, int64_t dim,
42
+ const Tensor& index
43
+ ) {
44
+ auto self_dims = ensure_nonempty_dim(self.dim());
45
+ TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
46
+ "Index tensor must have the same number of dimensions as input tensor"
47
+ );
48
+
49
+ for (const auto i : c10::irange(self_dims)) {
50
+ if (i != dim) {
51
+ TORCH_CHECK(
52
+ ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
53
+ "Size does not match at dimension ", i,
54
+ " expected index ", index.sizes(),
55
+ " to be smaller than self ", self.sizes(),
56
+ " apart from dimension ", dim
57
+ );
58
+ }
59
+ }
60
+ }
61
+
62
+ // Used for `scatter` and `scatter_add`
63
+ // Tests:
64
+ // 1. index.size(d) <= self.size(d) for all d != dim
65
+ // 2. index.size(d) <= src.size(d) for all d if src is a Tensor
66
+ // 3. index.dim() == self.dim() == src.dim()
67
+ inline void scatter_shape_check(
68
+ const Tensor& self, int64_t dim, const Tensor& index,
69
+ const std::optional<Tensor>& src_opt = std::nullopt
70
+ ) {
71
+ if (index.numel() == 0) return;
72
+ TORCH_CHECK(
73
+ ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
74
+ "Index tensor must have the same number of dimensions as self tensor"
75
+ );
76
+
77
+ bool is_wrong_shape = false;
78
+ int64_t self_dims = ensure_nonempty_dim(self.dim());
79
+
80
+ // Check: index.size(d) <= self.size(d) for all d != dim
81
+ for (const auto d : c10::irange(self_dims)) {
82
+ int64_t index_d_size = ensure_nonempty_size(index, d);
83
+ if (d == dim) continue;
84
+ if (index_d_size > ensure_nonempty_size(self, d)) {
85
+ is_wrong_shape = true;
86
+ break;
87
+ }
88
+ }
89
+
90
+ // Check: index.size(d) <= src.size(d) for all d if src is Tensor
91
+ if (!is_wrong_shape && src_opt.has_value()) {
92
+ const auto& src = src_opt.value();
93
+ for (const auto d : c10::irange(self_dims)) {
94
+ int64_t index_d_size = ensure_nonempty_size(index, d);
95
+ if (index_d_size > ensure_nonempty_size(src, d)) {
96
+ is_wrong_shape = true;
97
+ break;
98
+ }
99
+ }
100
+ }
101
+
102
+ if (src_opt.has_value()) {
103
+ const auto& src = src_opt.value();
104
+
105
+ TORCH_CHECK(
106
+ ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
107
+ "Index tensor must have the same number of dimensions as src tensor"
108
+ );
109
+
110
+ TORCH_CHECK(!is_wrong_shape,
111
+ "Expected index ", index.sizes(),
112
+ " to be smaller than self ", self.sizes(),
113
+ " apart from dimension ", dim,
114
+ " and to be smaller size than src ", src.sizes()
115
+ );
116
+ }
117
+ else {
118
+ TORCH_CHECK(!is_wrong_shape,
119
+ "Expected index ", index.sizes(),
120
+ " to be smaller than self ", self.sizes(),
121
+ " apart from dimension ", dim
122
+ );
123
+ }
124
+ }
125
+
126
+ } // anonymous namespace
127
+
128
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/SharedReduceOps.h ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // Please note that this file is
3
+ // used across both CPU and GPU.
4
+
5
+ #include <type_traits>
6
+ #include <complex>
7
+ #include <c10/macros/Macros.h>
8
+ #include <ATen/detail/FunctionTraits.h>
9
+ #include <ATen/NumericUtils.h>
10
+ #if defined(__CUDACC__)
11
+ #include <ATen/cuda/DeviceUtils.cuh>
12
+ #include <ATen/native/cuda/DeviceSqrt.cuh>
13
+ #elif defined(__HIPCC__)
14
+ #include <ATen/hip/DeviceUtils.cuh>
15
+ #include <ATen/native/hip/DeviceSqrt.cuh>
16
+ #endif
17
+ #if defined(__CUDACC__) || defined(__HIPCC__)
18
+ #include <thrust/pair.h>
19
+ #else
20
+ #include <cmath>
21
+ #define device_sqrt std::sqrt
22
+ #endif
23
+ #if defined(__CUDACC__) || defined(__HIPCC__)
24
+ template <typename scalar_t>
25
+ inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
26
+ #if defined(__HIPCC__)
27
+ // TODO: remove this special case for HIP when issue is fixed:
28
+ // https://github.com/ROCm-Developer-Tools/HIP/issues/2209
29
+ scalar_t max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b));
30
+ #else
31
+ scalar_t max = at::_isnan(b) ? b : std::max(a, b);
32
+ #endif
33
+ return max;
34
+ }
35
+ template <typename scalar_t>
36
+ inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
37
+ #if defined(__HIPCC__)
38
+ // TODO: remove this special case for HIP when issue is fixed:
39
+ // https://github.com/ROCm-Developer-Tools/HIP/issues/2209
40
+ scalar_t min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b));
41
+ #else
42
+ scalar_t min = at::_isnan(b) ? b : std::min(a, b);
43
+ #endif
44
+ return min;
45
+ }
46
+ #define MAX(X, Y) max_propagate_nan(X,Y)
47
+ #define MIN(X, Y) min_propagate_nan(X,Y)
48
+ #else
49
+ #include <ATen/native/cpu/zmath.h>
50
+ #define MAX(X, Y) max_impl(X,Y)
51
+ #define MIN(X, Y) min_impl(X,Y)
52
+ #endif
53
+
54
+ // ROCM hcc doesn't work well with using std:: in kernel functions
55
+ #if defined(__CUDA_ARCH__)
56
+ #include <c10/cuda/CUDAMathCompat.h>
57
+ #define compat_pow c10::cuda::compat::pow
58
+ #elif defined(__HIPCC__)
59
+ #include <c10/hip/HIPMathCompat.h>
60
+ #define compat_pow c10::hip::compat::pow
61
+ #else
62
+ #define compat_pow std::pow
63
+ #endif
64
+
65
+ namespace at { namespace native {
66
+
67
+ namespace detail {
68
+
69
+ #if defined(__CUDACC__) || defined(__HIPCC__)
70
+ template <typename T1, typename T2> using pair = thrust::pair<T1, T2>;
71
+ #else
72
+ template <typename T1, typename T2> using pair = std::pair<T1, T2>;
73
+ #endif
74
+
75
+ } // namespace detail
76
+
77
// Running state for Welford's single-pass mean/variance algorithm:
// `mean` and `m2` (sum of squared deviations) plus the element count kept
// twice — as exact integer `n` for accumulation, and as floating-point `nf`
// for `combine`, where an int32 count could overflow (see WelfordOps).
template <typename scalar_t, typename index_t>
struct WelfordData {
  scalar_t mean;
  scalar_t m2;
  index_t n;
  scalar_t nf;

  // Empty state: zero elements observed.
  C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}

  C10_HOST_DEVICE WelfordData(
      scalar_t mean,
      scalar_t m2,
      index_t n,
      scalar_t nf)
      : mean(mean), m2(m2), n(n), nf(nf) {}
};
93
+
94
+
95
// Reduction ops implementing Welford's numerically stable one-pass
// mean/variance. `correction` is subtracted from the count when dividing
// (Bessel-style correction); `take_sqrt` selects std over var. `res_t` is a
// pair-like type constructed as (var-or-std, mean).
template <typename scalar_t, typename acc_scalar_t, typename index_t, typename res_t>
struct WelfordOps {
  acc_scalar_t correction;
  bool take_sqrt;
 public:
  using acc_t = WelfordData<acc_scalar_t, index_t>;
  // Fold one input element into the running state.
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
    // We accumulate n in index_t to avoid cumulative rounding error, but still
    // need nf for use in combine where int32 may overflow.
    index_t new_n = acc.n + 1;
    acc_scalar_t new_nf = static_cast<acc_scalar_t>(new_n);
    acc_scalar_t delta = data - acc.mean;
    acc_scalar_t new_mean = acc.mean + delta / new_nf;
    acc_scalar_t new_delta = data - new_mean;
    return {
      new_mean,
      // delta * new_delta is the standard Welford M2 update.
      acc.m2 + delta * new_delta,
      new_n,
      new_nf,
    };
  }
  // Merge two partial Welford states (parallel-combination formula).
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    if (a.nf == 0) {
      return b;
    }
    if (b.nf == 0) {
      return a;
    }
    acc_scalar_t delta = b.mean - a.mean;
    acc_scalar_t new_count = a.nf + b.nf;
    acc_scalar_t nb_over_n = b.nf / new_count;
    return {
      a.mean + delta * nb_over_n,
      a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
      // setting acc.n as -1 since acc.n might not be able to represent the count
      // correctly within its range, setting it to -1 to avoid confusion
      -1,
      new_count
    };
  }
  // Produce the final (var-or-std, mean). When the count does not exceed
  // `correction`, the divisor is 0 and the division intentionally yields
  // inf/nan — hence the UBSan suppression attribute.
  inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ {
    const auto mean = static_cast<scalar_t>(acc.mean);
    const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
    const auto var = acc.m2 / divisor;
    res_t results(take_sqrt ? device_sqrt(var) : var, mean);
    return results;
  }

  // Welford state carries no element index, so base-index translation is a
  // no-op (kept for interface parity with the arg-reduction ops).
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Shuffle every field down the warp; used by the CUDA/HIP reduction kernel.
  inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {
      WARP_SHFL_DOWN(acc.mean, offset)
      , WARP_SHFL_DOWN(acc.m2, offset)
      , WARP_SHFL_DOWN(acc.n, offset)
      , WARP_SHFL_DOWN(acc.nf, offset)
    };
  }
#endif
  C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt)
    : correction(correction), take_sqrt(take_sqrt) {}
};
160
+
161
// Reduction ops for a mean: elements are summed in `acc_t` and the final sum
// is scaled by `factor` in project (presumably 1/count — confirm at the call
// sites that construct MeanOps).
template <typename scalar_t, typename acc_t=scalar_t, typename factor_t=acc_t, typename out_t = acc_t>
struct MeanOps {
  factor_t factor;

  inline C10_DEVICE acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const {
    return combine(a, static_cast<acc_t>(b));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Multiply (rather than divide) by the precomputed factor.
  inline C10_DEVICE out_t project(acc_t a) const {
    return a * factor;
  }

  // No per-element index in the accumulator; translation is a no-op.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
#endif

  MeanOps(factor_t factor): factor(factor) {
  }
};
190
+
191
// This accumulator template is used to calculate the minimum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMinOps {

  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    // MIN is the NaN-propagating helper defined at the top of this file
    // (device builds) or in cpu/zmath.h (CPU builds).
    return MIN(acc, static_cast<acc_t>(std::abs(data)));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return MIN(a, b);
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};

// This accumulator template is used to calculate the maximum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMaxOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    // MAX propagates NaN, mirroring AbsMinOps above.
    return MAX(acc, static_cast<acc_t>(std::abs(data)));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return MAX(a, b);
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
249
+
250
// This accumulator template is used to calculate the p-norm of the absolute
// value of a set of numbers: sum(|x|^p) ^ (1/p).
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOps {
  acc_t norm_;  // the exponent p

  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + compat_pow(static_cast<acc_t>(std::abs(data)), norm_);
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Final root: sum^(1/p).
  inline C10_DEVICE out_t project(acc_t a) const {
    return compat_pow(a, static_cast<acc_t>(1.0) / norm_);
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif

  NormOps(acc_t norm_): norm_(norm_) {
  }
};
283
+
284
// This accumulator template is used to calculate the order zero "norm" of a
// set of numbers, i.e. the count of non-zero entries.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormZeroOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    // Add 1 for each non-zero element.
    return acc + (data == static_cast<scalar_t>(0) ? static_cast<acc_t>(0) : static_cast<acc_t>(1));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }


#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};

// This accumulator template is used to calculate the order one norm of the
// absolute value of a set of numbers, i.e. sum(|x|).
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOneOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + static_cast<acc_t>(std::abs(data));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
342
+
343
+
344
// Tag type used only to carry `acc_t` into the abs_if_complex overload set.
template<typename acc_t>
struct AbsSwitch {};

// For real inputs this is just a cast — NO absolute value is taken. That is
// fine for NormTwoOps below because the value is squared immediately.
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(data);
}

// Complex overloads: take the magnitude so the accumulator stays real.
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}

template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}

// This accumulator template is used to calculate the order two norm of the
// absolute value of a set of numbers: sqrt(sum(|x|^2)).
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormTwoOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    acc_t data_ = abs_if_complex(data, AbsSwitch<acc_t>());
    return acc + data_ * data_;
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return device_sqrt(a);
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
391
+
392
// Sum that treats NaN inputs as zero (implements nansum).
template <typename acc_t, typename data_t>
struct NanSumOps {
  inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const {
    // Skip NaNs instead of letting them poison the sum.
    return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b});
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Narrow the accumulator back to the output data type.
  inline C10_DEVICE data_t project(acc_t a) const {
    return data_t{a};
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
#endif
};
416
+
417
+ namespace detail {
418
+
419
// Comparator for min/argmin: returns true when (a, idx_a) should win over
// (b, idx_b). NaN always wins (propagates), and exact ties — including two
// NaNs — are broken by the lower index.
template <typename scalar_t>
struct LessOrNan {
  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
    // If (a == b), then choose the one with lower idx, else min(a, b)
    if (at::_isnan(a)) {
      if (at::_isnan(b)) {
        return idx_a < idx_b;
      }
      return true;
    }
    return (a == b) ? idx_a < idx_b : (a < b);
  }
};

// Comparator for max/argmax; mirror image of LessOrNan with the same
// NaN-propagation and lower-index tie-breaking.
template <typename scalar_t>
struct GreaterOrNan {
  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
    // If (a == b), then choose the one with lower idx, else max(a, b)
    if (at::_isnan(a)) {
      if (at::_isnan(b)) {
        return idx_a < idx_b;
      }
      return true;
    }
    return (a == b) ? idx_a < idx_b : (a > b);
  }
};
446
+
447
// Generic (value, index) reduction parameterized by a comparator
// (LessOrNan/GreaterOrNan). Used for min/max, which return both the winning
// value and its index.
template <typename comp_t>
struct MinMaxReductionOps {
  using scalar_t = typename binary_function_traits<comp_t>::arg1_t;
  using index_t = int64_t;
  using arg_t = detail::pair<scalar_t, index_t>;

  // min/max return the full (value, index) pair.
  static C10_DEVICE arg_t project(arg_t arg) {
    return arg;
  }

  // Keep the current pair if the comparator prefers it, else adopt (val, idx).
  static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) {
    return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx);
  }

  static C10_DEVICE arg_t combine(arg_t a, arg_t b) {
    return comp_t{}(a.first, b.first, a.second, b.second) ? a : b;
  }

  // Shift a chunk-local index into the global index space.
  static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) {
    return {a.first, a.second + base_idx};
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) {
    return arg_t(WARP_SHFL_DOWN(arg.first, offset),
                 WARP_SHFL_DOWN(arg.second, offset));
  }
#endif
};

// Same reduction, but the projection drops the value and returns only the
// index — the argmin/argmax result.
template <typename comp_t>
struct ArgReductionOps : public MinMaxReductionOps<comp_t> {
  using typename MinMaxReductionOps<comp_t>::scalar_t;
  using typename MinMaxReductionOps<comp_t>::index_t;
  using typename MinMaxReductionOps<comp_t>::arg_t;

  static C10_DEVICE index_t project(arg_t arg) {
    return arg.second;
  }
};
487
+
488
+ } // namespace detail
489
+
490
// Concrete instantiations: argmax/argmin return indices; max/min return
// (value, index) pairs. All share the NaN-propagating comparators above.
template <typename scalar_t>
struct ArgMaxOps :
  public detail::ArgReductionOps<detail::GreaterOrNan<scalar_t>> {
};

template <typename scalar_t>
struct ArgMinOps :
  public detail::ArgReductionOps<detail::LessOrNan<scalar_t>> {
};

template <typename scalar_t>
struct MinOps :
  public detail::MinMaxReductionOps<detail::LessOrNan<scalar_t>> {
};

template <typename scalar_t>
struct MaxOps :
  public detail::MinMaxReductionOps<detail::GreaterOrNan<scalar_t>> {
};
509
+
510
// Single-pass reduction computing min and max simultaneously (aminmax).
// The accumulator is a (running_min, running_max) pair; NaN in either slot
// propagates because NaN comparisons are false for the non-NaN side.
template <typename scalar_t, typename acc_scalar_t, typename index_t>
struct MinMaxOps {
  using acc_t = detail::pair<acc_scalar_t, acc_scalar_t>;
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
    // A single element is both its own min and max.
    return combine(acc, {data, data});
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    auto min_val = (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first;
    auto max_val = (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second;

    return {min_val, max_val};
  }

  inline C10_DEVICE acc_t project(acc_t acc) const {
    return acc;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {
      WARP_SHFL_DOWN(acc.first, offset), WARP_SHFL_DOWN(acc.second, offset)
    };
  }
#endif
};
540
+
541
+ }} // namespace at::native
542
+
543
+ #undef MAX
544
+ #undef MIN
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Sorting.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <cstdint>
5
+
6
+ namespace at {
7
+ class TensorBase;
8
+ }
9
+
10
+ namespace at::native {
11
+
12
// How torch.quantile resolves a quantile that falls between two data points.
enum class QUANTILE_INTERPOLATION_MODE : uint8_t {
  LINEAR,
  HIGHER,
  MIDPOINT,
  NEAREST
};

// Kernel signatures for the sort/topk dispatch stubs.
// NOTE(review): the meaning and order of the arguments are defined by the
// kernels that register these stubs — confirm there before relying on them.
using sort_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, bool, bool);
using topk_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, int64_t, bool, bool);

DECLARE_DISPATCH(sort_fn, sort_stub);
DECLARE_DISPATCH(topk_fn, topk_stub);

// Presumably fills `indices` with sequential positions along `dim` —
// see the definition in the sorting implementation for the exact contract.
void _fill_indices(const TensorBase &indices, int64_t dim);
27
+
28
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/SortingUtils.h ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/NumericUtils.h>
4
+ #include <ATen/native/Resize.h>
5
+ #include <c10/util/irange.h>
6
+
7
+ #ifndef AT_PER_OPERATOR_HEADERS
8
+ #include <ATen/Functions.h>
9
+ #else
10
+ #include <ATen/ops/empty.h>
11
+ #endif
12
+
13
+ namespace at::native {
14
+
15
// ensure we get good values and indices for kthvalue, mode
// this will always be with the reducing dim as 1-d
//
// Validates and resizes (or allocates) the `values`/`indices` outputs so the
// reduced dimension has size 1. If a defined output was produced with
// keepdim=false, it is unsqueezed first so resize_output preserves the
// caller's (possibly non-contiguous) tensor.
inline void _reduction_with_indices_allocate_or_resize_output(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim_,
    bool keepdim) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  auto result_sizes = self.sizes().vec();
  if (!result_sizes.empty()) {
    result_sizes[dim] = 1;
  }
  if (values.defined()) {
    TORCH_CHECK(
        self.options().type_equal(values.options()),
        "output values must be of same type as input");
    if (!keepdim && values.dim() == self.dim() - 1) {
      // unsqueeze to preserve passed in noncontiguous tensor in resize
      values.unsqueeze_(dim);
    }
    resize_output(values, result_sizes);
  } else {
    values = at::empty(result_sizes, self.options());
  }
  if (indices.defined()) {
    TORCH_CHECK(
        indices.dtype() == kLong, "output indices must be of scalar type Long");
    TORCH_CHECK(
        indices.device() == self.device(),
        "output indices must be on same device as input");
    if (!keepdim && indices.dim() == self.dim() - 1) {
      // unsqueeze to preserve passed in noncontiguous tensor in resize
      indices.unsqueeze_(dim);
    }
    resize_output(indices, result_sizes);
  } else {
    indices = at::empty(result_sizes, self.options().dtype(kLong));
  }
}
55
+
56
// ensure we get good values and indices for topk
//
// Validates and resizes (or allocates) the `values`/`indices` outputs so the
// selected dimension has size `k`. Unlike the kthvalue helper above, outputs
// are resized directly with resize_() and no keepdim handling is needed.
inline void _allocate_or_resize_output_with_indices(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim_,
    int64_t k) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  auto result_sizes = self.sizes().vec();
  if (!result_sizes.empty()) {
    result_sizes[dim] = k;
  }
  if (values.defined()) {
    TORCH_CHECK(
        self.options().type_equal(values.options()),
        "output values must be of same type as input");
    values.resize_(result_sizes);
  } else {
    values = at::empty(result_sizes, self.options());
  }
  if (indices.defined()) {
    TORCH_CHECK(
        indices.dtype() == kLong, "output indices must be of scalar type Long");
    TORCH_CHECK(
        indices.device() == self.device(),
        "output indices must be on same device as input");
    indices.resize_(result_sizes);
  } else {
    indices = at::empty(result_sizes, self.options().dtype(kLong));
  }
}
87
+
88
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/SparseTensorUtils.h ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Parallel.h>
4
+ #include <ATen/SparseTensorImpl.h>
5
+ #include <ATen/core/Tensor.h>
6
+
7
+ #ifndef AT_PER_OPERATOR_HEADERS
8
+ #include <ATen/Functions.h>
9
+ #else
10
+ #include <ATen/ops/empty.h>
11
+ #include <ATen/ops/tensor.h>
12
+ #endif
13
+
14
+ namespace at::sparse {
15
+
16
// Just for documentary purposes
using SparseTensor = Tensor;
using SparseType = Type;

// This is an internal utility function for getting at the SparseTensorImpl,
// so that we can write sparse tensor specific accessors for special fields
// in SparseTensor. You should only use this for writing low level
// setters/getters for SparseTensorImpl fields; otherwise, you should use
// the low level setters/getters that were implemented using this.
//
// This may be called repeatedly, so make sure it's pretty cheap.
inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
  TORCH_INTERNAL_ASSERT(
      self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
  // Safe downcast: the assert above guarantees a sparse layout.
  return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
}
32
+
33
// Takes indices and values and directly puts them into the sparse tensor, no
// copy. This used to be called THSTensor_(_move)
//
// Note: no validation is performed here — the "_unsafe" setter trusts the
// caller to provide consistent indices/values.
inline void alias_into_sparse(
    const SparseTensor& self,
    const Tensor& indices,
    const Tensor& values) {
  get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
}
41
+
42
+ // Take indices and values and makes a (data) copy of them to put into the
43
+ // sparse indices/values. This used to be called THSTensor_(_set)
44
+ inline void copy_into_sparse(
45
+ const SparseTensor& self,
46
+ const Tensor& indices,
47
+ const Tensor& values,
48
+ bool non_blocking) {
49
+ alias_into_sparse(
50
+ self,
51
+ indices.to(self._indices().options(), non_blocking, /*copy=*/true),
52
+ values.to(self._values().options(), non_blocking, /*copy=*/true));
53
+ }
54
+
55
// TODO: put this into the public API
// True iff the two Tensors share the same underlying TensorImpl (identity,
// not value equality).
inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
  return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
}

// True iff both sparse tensors split their dimensions the same way into
// sparse and dense parts.
inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
  return self.sparse_dim() == src.sparse_dim() &&
      self.dense_dim() == src.dense_dim();
}
64
+
65
+ // Give us a new values tensor, with the same dimensionality
66
+ // as 'values' but with a new number of non-zero elements.
67
+ // TODO: Expose this for real in ATen, some day?
68
+ // NB: Doesn't preserve data.
69
+ inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
70
+ std::vector<int64_t> size = values.sizes().vec();
71
+ size[0] = nnz;
72
+ return at::empty(size, values.options());
73
+ }
74
+
75
// NOTE [ Flatten Sparse Indices ]
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
// indices tensor. E.g.,
//   input = [[2, 4, 0],
//            [3, 1, 10]]
//   full_size = [2, 12]
//   output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
//
// In other words, assuming that each `indices[i, :]` is a valid index to a
// tensor `t` of shape `full_size`. This returns the corresponding indices to
// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
// If `force_clone` is true, the result will be forced to be a clone rather
// than a possible alias of `indices`.
TORCH_API Tensor flatten_indices(
    const Tensor& indices,
    IntArrayRef full_size,
    bool force_clone = false);

// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten
// Sparse Indices ], except this one allows partial flatten: only flatten on
// specified dims. Note that the flatten indices might be uncoalesced if
// dims_to_flatten.size() < sparse_dim. Also if input indices is already
// coalesced, the flattened indices will also be sorted.
//
// args:
//    indices: sparse tensor indices
//    sizes: sparse tensor sizes
//    dims_to_flatten: a list of dim index to flatten
//
// Ex1:
//   indices = [[2, 4, 0],
//              [3, 1, 3]]
//   sizes = [2, 12]
//   dims_to_flatten = [0, 1]
//   new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
//
// Ex2:
//   dims_to_flatten = [1]
//   new_indices = [ 3, 1, 3 ] # uncoalesced
TORCH_API Tensor flatten_indices_by_dims(
    const Tensor& indices,
    const IntArrayRef& sizes,
    const IntArrayRef& dims_to_flatten);

// Find the CSR representation for a row `indices` from the COO format
TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);

TORCH_API Tensor zeros_like_with_indices(const Tensor& t);
123
+
124
// Snapshot of a tensor's sizes/strides copied into fixed-capacity
// std::arrays, so the geometry can be copied around by value independently
// of the tensor. Dereference (*holder) yields (sizes, strides).
template <size_t static_shape_max_len>
class TensorGeometryHolder {
  using geometry_holder_t = std::array<int64_t, static_shape_max_len>;

 public:
  explicit TensorGeometryHolder(
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options = {}) {
    // `options` is unused here; it is accepted so this primary template and
    // the <0> specialization (which needs it to pick a device) share a
    // constructor signature.
    std::copy(sizes.begin(), sizes.end(), t_sizes.begin());
    std::copy(strides.begin(), strides.end(), t_strides.begin());
  }

  explicit TensorGeometryHolder(const Tensor& t)
      : TensorGeometryHolder(t.sizes(), t.strides()) {}

  auto operator*() const {
    return std::make_tuple(t_sizes, t_strides);
  }

 private:
  geometry_holder_t t_sizes;
  geometry_holder_t t_strides;
};
148
+
149
// Specialization for static_shape_max_len == 0: the geometry is stored in a
// Tensor on the device given by `options` instead of in stack arrays. The
// sizes and strides are staged in a single 2 x ndim CPU buffer and moved to
// the target device in one copy. Dereference yields raw int64_t pointers
// into that tensor.
template <>
class TensorGeometryHolder<0> {
  using geometry_holder_t = Tensor;

 public:
  explicit TensorGeometryHolder(
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options) {
    const int64_t t_ndims = sizes.size();
    const auto cpu_options = TensorOptions(options).dtype(kLong).device(kCPU);
    // Row 0 holds sizes, row 1 holds strides.
    Tensor t_sizes_and_strides_cpu = at::empty({2, t_ndims}, cpu_options);
    t_sizes_and_strides_cpu.select(0, 0).copy_(at::tensor(sizes, cpu_options));
    t_sizes_and_strides_cpu.select(0, 1).copy_(
        at::tensor(strides, cpu_options));
    const Tensor t_sizes_and_strides =
        t_sizes_and_strides_cpu.to(options.device());
    // Both members view into the same device tensor, keeping it alive.
    t_sizes = t_sizes_and_strides.select(0, 0);
    t_strides = t_sizes_and_strides.select(0, 1);
  }

  explicit TensorGeometryHolder(const Tensor& t)
      : TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {}

  auto operator*() const {
    return std::make_tuple(
        t_sizes.template data_ptr<int64_t>(),
        t_strides.template data_ptr<int64_t>());
  }

 private:
  geometry_holder_t t_sizes;
  geometry_holder_t t_strides;
};
183
+
184
+ // Return all indices of a tensor with the given shape.
185
+ //
186
+ // full_coo_indices(shape) is equivalent to
187
+ // torch.ones(shape).nonzero().transpose(-2, -1) but much faster.
188
+ TORCH_API Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options);
189
+
190
+ } // namespace at::sparse
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // Indexing tensors by tensors
4
+
5
+ #include <ATen/core/List.h>
6
+ #include <ATen/core/Tensor.h>
7
+ #include <ATen/native/DispatchStub.h>
8
+ #include <ATen/native/ReductionType.h>
9
+
10
+ namespace at {
11
+ struct TensorIterator;
12
+ }
13
+
14
+ namespace at::native {
15
+
16
+ using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<std::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
17
+ using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<std::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);
18
+ using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
19
+ using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
20
+ using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
21
+ using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
22
+ using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
23
+ const Tensor& src, const ReductionType& reduce);
24
+ using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
25
+ const Scalar& value, const ReductionType& reduce);
26
+ using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
27
+ const Tensor& src, const ReductionType& reduce);
28
+
29
+ DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
30
+ DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
31
+ DECLARE_DISPATCH(gather_fn, gather_stub);
32
+ DECLARE_DISPATCH(scatter_fn, scatter_stub);
33
+ DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
34
+ DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
35
+ DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
36
+ DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
37
+ DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub);
38
+
39
+ TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<std::optional<at::Tensor>>& indices);
40
+
41
+ using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
42
+ using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
43
+ using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);
44
+
45
+ DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
46
+ DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
47
+ DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);
48
+
49
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/TensorIterator.h>

namespace at::native {
// NOTE(review): an anonymous namespace in a header gives each TU its
// own copy of this helper; `inline` alone would avoid the duplication.
namespace {
#ifndef STRIP_ERROR_MESSAGES
// Formats the sizes of every *defined* tensor in `tensors` as a
// comma-separated list (e.g. "[2, 3], [3]"). Only used to build error
// messages, hence compiled out under STRIP_ERROR_MESSAGES.
inline std::string shapes_as_str(TensorList tensors) {
  std::ostringstream os;
  bool first = true;
  for (auto& tensor : tensors) {
    if (tensor.defined()) {
      if (!first) {
        os << ", ";
      }
      os << tensor.sizes();
      first = false;
    }
  }
  return os.str();
}
#endif
} // anonymous namespace
25
+
26
// Decides whether index_put_(indices, value) can be lowered to the much
// cheaper masked_fill_. The fast path requires (as checked below):
//   * `value` is a single-element CPU tensor (a scalar fill value);
//   * exactly one index slot holds a defined tensor, it is a bool/byte
//     mask on the same device as `self`, and every other slot is None;
//   * the mask's sizes match the `self` dimensions it spans.
// Returns {true, mask unsqueezed to self's rank} on success, otherwise
// {false, undefined Tensor}.
inline std::tuple<bool, Tensor> canDispatchToMaskedFill(const Tensor& self, const torch::List<std::optional<at::Tensor>>& indices,
const Tensor& value){
  if (!(value.numel() ==1 && value.device().is_cpu())){
    return std::make_tuple(false,Tensor());
  }
  // num_ind counts self dimensions consumed so far: one per None slot,
  // mask.dim() for the mask itself.
  int64_t num_ind = 0;
  Tensor mask;
  auto self_device = self.device();
  for (const std::optional<Tensor>& i: indices) {
    if (!i.has_value() || !(*i).defined()){
      num_ind++;
    } else {
      const Tensor &index = *i;
      // Disqualify on: non-mask dtype, device mismatch, or a second
      // defined index tensor (mask.defined() already true).
      if ((index.scalar_type() != kByte && index.scalar_type() != kBool) ||
          index.device() != self_device || mask.defined()){
        return std::make_tuple(false, Tensor());
      } else {
        mask = index;
        for (const auto j : c10::irange(index.dim())) {
          int64_t srcIdx = num_ind + j;
          TORCH_CHECK_INDEX(index.size(j) == self.size(srcIdx), "The shape of the mask ", index.sizes(), " at index ", j,
            " does not match the shape of the indexed tensor ", self.sizes(), " at index ", srcIdx);
        }
        num_ind += mask.ndimension();
      }
    }
  }
  // Append singleton dims so the mask broadcasts over the trailing,
  // un-indexed dimensions of self.
  for (C10_UNUSED const auto i : c10::irange(num_ind, self.ndimension())) {
    mask = mask.unsqueeze(-1);
  }
  return std::make_tuple(true, mask);
}
58
+
59
// Normalizes raw advanced-indexing arguments into an AdvancedIndex:
//  1. validate index dtypes (bool/byte masks and int/long tensors);
//  2. expand each mask into equivalent long index tensors;
//  3. broadcast all index tensors against each other;
//  4. pad with undefined tensors up to self.dim();
//  5. if the defined indices are not adjacent, transpose self and the
//     index list so they sit contiguously at the front;
//  6. move indices to self's device and promote int32 -> int64.
inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
  checkIndexTensorTypes(orig, /*allow_int*/ true);
  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
  auto indices = expandTensors(self, orig);
  // next broadcast all index tensors together
  try {
    indices = expand_outplace(indices);
  } catch (std::exception& e) {
    TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"
                      " with shapes ", shapes_as_str(indices));
  }
  // add missing null Tensors so that it matches self.dim()
  while (indices.size() < (size_t)self.dim()) {
    indices.emplace_back();
  }
  // if the non-null indices are not all adjacent, transpose self and indices
  // together so that they're adjacent at the front
  if (!hasContiguousSubspace(indices)) {
    std::tie(self, indices) = transposeToFront(self, indices);
  }
  // Ensure indices are on the same device as self
  for (auto & indice : indices) {
    if (indice.defined() && indice.device() != self.device()) {
      indice = indice.to(self.device());
    }
  }
  // Kernels only handle int64 indices; widen any int32 index tensors.
  for (auto & indice : indices) {
    if (indice.defined() && indice.dtype() == at::kInt) {
      indice = indice.to(at::kLong);
    }
  }

  return AdvancedIndex(self, indices);
}

} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorDimApply.h ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>

namespace at::native {
//input tensors are non-zero dim and non-empty
// Invokes `func` once per 1-D slice taken along dimension `dim`,
// walking self/values/indices in lockstep with base pointers:
//   func(self_ptr, values_ptr, indices_ptr,
//        self_dim_size, self_stride, values_stride, indices_stride)
// NOTE(review): assumes values and indices match self's sizes/layout on
// every dimension other than `dim` — confirm at call sites.
template<typename T1, typename T2, typename Function>

void tensor_dim_apply3(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim, Function func) {
  int ndims = self.dim();
  int tensor_dim_apply_has_finished = 0;
  // Odometer over every dimension except `dim`.
  std::vector<int64_t> counter(ndims, 0);
  const T1* self_data = self.const_data_ptr<T1>();
  T1* values_data = values.data_ptr<T1>();
  T2* indices_data = indices.data_ptr<T2>();
  int64_t self_stride = self.stride(dim);
  int64_t values_stride = values.stride(dim);
  int64_t indices_stride = indices.stride(dim);
  int self_dim_size = self.size(dim);

  while (!tensor_dim_apply_has_finished) {
    func(self_data, values_data, indices_data, self_dim_size, self_stride, values_stride, indices_stride);
    if (ndims == 1) {
      break;
    }
    // Advance the odometer: bump the first non-`dim` dimension with
    // room; a dimension that overflows is reset (pointers rewound) and
    // the carry moves outward.
    for (const auto dim_i : c10::irange(ndims)) {
      if (dim_i == dim) {
        // `dim` is skipped; if it is also the last dimension there is
        // nothing left to advance, so iteration is complete.
        if (dim_i == (ndims - 1)) {
          tensor_dim_apply_has_finished = 1;
          break;
        }
        continue;
      }
      counter[dim_i]++;
      self_data += self.stride(dim_i);
      values_data += values.stride(dim_i);
      indices_data += indices.stride(dim_i);

      if (counter[dim_i] == self.size(dim_i)) {
        if (dim_i == ndims-1) {
          tensor_dim_apply_has_finished = 1;
          break;
        } else {
          // Rewind this dimension and carry into the next one.
          self_data -= counter[dim_i]*self.stride(dim_i);
          values_data -= counter[dim_i]*values.stride(dim_i);
          indices_data -= counter[dim_i]*indices.stride(dim_i);
          counter[dim_i] = 0;
        }
      } else {
        break;
      }
    }
  }
}
} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorFactories.h ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/EmptyTensor.h>
5
+ #include <ATen/TensorIterator.h>
6
+ #include <ATen/Dispatch.h>
7
+ #include <ATen/Dispatch_v2.h>
8
+ #include <ATen/native/DispatchStub.h>
9
+
10
+ #ifndef AT_PER_OPERATOR_HEADERS
11
+ #include <ATen/Functions.h>
12
+ #else
13
+ #include <ATen/ops/scalar_tensor.h>
14
+ #endif
15
+
16
+ namespace at::native {
17
+ // Different combinations of row, col, and offset can lead to two cases:
18
+ //
19
+ // Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
20
+ // Example A: offset > 0
21
+ // 1 1 0 0 0
22
+ // 1 1 1 0 0
23
+ // 1 1 1 1 0
24
+ // Example B: offset <= 0
25
+ // 0 0 0
26
+ // 1 0 0
27
+ // 1 1 0
28
+ // In this case, we calculate the number of elements in the first row and
29
+ // last row of the tril respectively, and then compute the tril size.
30
+ //
31
+ // Case 2 - Trapezoid + Rectangle: row + offset > col
32
+ // Example:
33
+ // 1 1 0
34
+ // 1 1 1
35
+ // 1 1 1
36
+ // In this case, we first calculate the size of top trapezoid, and then
37
+ // calculate the size of the bottom rectangle.
38
// Number of elements in the lower triangle (diagonal offset `offset`)
// of a `row` x `col` matrix — i.e. the count of ones produced by
// torch.tril(torch.ones(row, col), offset).
//
// The region splits into a top trapezoid (rows whose length still grows
// with the row index) plus, when row + offset > col, a bottom rectangle
// of rows saturated at `col` elements (see the file-level diagrams).
inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
  // A degenerate matrix has no tril at all.
  if (row == 0 || col == 0) {
    return 0;
  }
  // Length of the topmost row of the tril: capped by col for positive
  // offsets; otherwise 1 if the tril is non-empty, else 0.
  int64_t first_row_len;
  if (offset > 0) {
    first_row_len = std::min<int64_t>(col, 1 + offset);
  } else {
    first_row_len = (row + offset > 0) ? 1 : 0;
  }
  // Length of the bottom row, clamped into [0, col].
  const int64_t last_row_len = std::clamp<int64_t>(row + offset, int64_t{0}, col);
  // Count of rows that the tril touches, clamped into [0, row].
  const int64_t rows_nonempty = std::clamp<int64_t>(row + offset, int64_t{0}, row);
  const int64_t rows_trapezoid = last_row_len - first_row_len + 1;

  // Arithmetic-series sum over the trapezoid rows.
  int64_t total = (first_row_len + last_row_len) * rows_trapezoid / 2;

  // Any remaining rows form a full-width rectangle.
  const int64_t rows_rect = rows_nonempty - rows_trapezoid;
  if (rows_rect > 0) {
    total += rows_rect * col;
  }

  return total;
}
64
+
65
+ inline void check_args(
66
+ int64_t row, int64_t col, std::optional<Layout> layout_opt) {
67
+ TORCH_CHECK(row >= 0, "row must be non-negative, got", row);
68
+ TORCH_CHECK(col >= 0, "col must be non-negative, got", col);
69
+ if (layout_opt.has_value()) {
70
+ TORCH_CHECK(
71
+ *layout_opt == at::kStrided,
72
+ "only support layout=torch.strided, got",
73
+ *layout_opt)
74
+ }
75
+ }
76
+
77
using at::check_size_nonnegative;

// assumes maximum value in created tensor is n-1 (e.g., torch.randperm(n))
// Verifies that integers up to n-1 are exactly representable in
// `tensor`'s dtype; the bounds below are 2^(significand bits) + 1.
// NOTE(review): BFloat16 (exact up to 256) is not checked here —
// confirm whether callers can reach this with bfloat16 outputs.
inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) {
  // match defined() to behavior of checks below
  TORCH_CHECK(at::scalar_tensor(n>0?n-1:n, tensor.options()).defined(),
              "n is too large for result tensor type: '", tensor.toString(), "'");

  // Ensure sufficient precision for floating point representation.
  switch (tensor.scalar_type()) {
    case at::ScalarType::Half:
      // Half has an 11-bit significand: exact integers up to 2^11.
      TORCH_CHECK(n <= (int64_t(1) << 11) + 1, "n cannot be greater than 2049 for Half type.");
      break;
    case at::ScalarType::Float:
      TORCH_CHECK(n <= (int64_t(1) << 24) + 1, "n cannot be greater than 2^24+1 for Float type.");
      break;
    case at::ScalarType::Double:  // Unlikely to happen, but doesn't hurt to check
      TORCH_CHECK(n <= (int64_t(1) << 53) + 1, "n cannot be greater than 2^53+1 for Double type.");
      break;
    default:
      break;
  }
}
100
+
101
// Called by `empty*` functions when deterministic algorithms are enabled to
// fill the tensor with NaN if it is floating point or complex type, or fill
// with max value if it is integer type
// (makes reads of "uninitialized" memory reproducible and conspicuous).
inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
  if (tensor.is_floating_point() || tensor.is_complex()) {
    AT_DISPATCH_V2(
      tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
        tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
      }), AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf);
  } else {
    // Integer/bool tensors cannot hold NaN; use the dtype's max value.
    AT_DISPATCH_V2(
      tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
        tensor.fill_(std::numeric_limits<scalar_t>::max());
      }), kBool, AT_EXPAND(AT_INTEGRAL_TYPES_V2));
  }
  return tensor;
}
118
+
119
// The ZeroTensor allocator ignores whatever allocation is requested and always
// gives you nullptr
struct ZeroTensorAllocator final : public at::Allocator {
  ZeroTensorAllocator(at::Device device) : device_(device) {};
  ~ZeroTensorAllocator() override = default;
  // A ZeroTensor must never hold real memory; assert the invariant.
  static void deleter(void* const pointer) {
    TORCH_INTERNAL_ASSERT(!pointer);
  }
  DataPtr allocate(const size_t /*nbytes*/) override {
    return {nullptr, nullptr, &deleter, device_};
  }
  DeleterFnPtr raw_deleter() const override {
    return deleter;
  }
  // Nothing to copy: the data is identically zero by construction.
  void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const final {}
  at::Device device_;
};

using binary_fn = void (*)(TensorIterator&);

// Kernels building complex tensors from two real inputs
// (torch.complex / torch.polar).
DECLARE_DISPATCH(binary_fn, complex_stub);
DECLARE_DISPATCH(binary_fn, polar_stub);

} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once

#include <complex>
#include <type_traits>
#include <c10/core/ScalarType.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>


// This file includes utilities for dynamic_casting done by TensorIterator, see CUDALoops.cuh and Loops.h.

// dynamic_casting handles when the types expected by the iterator do not match the types of the arguments
// to the function that is being called.
// On CUDA, the cast is currently pushed down into the kernel (for performance reasons).
// On CPU, there is currently an internal assert that a dynamic_cast is not needed.

namespace at::native {

// `needs_dynamic_casting` compares the types expected by iterator
// (i.e. dtypes of the operands) with the actual type of the arguments
// (and returns) of func_t
// Recursive case: check argument (nargs - 1), then recurse over the
// remaining arguments; the 0-arity specialization below terminates the
// recursion by checking the return type.
template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
  static bool check(TensorIteratorBase& iter) {
    using traits = function_traits<func_t>;
    using cpp_type = typename traits::template arg<nargs - 1>::type;
    using cpp_map = c10::CppTypeToScalarType<cpp_type>;

    if (iter.input_dtype(nargs-1) != cpp_map::value) {
      return true;
    }
    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
  }
};

// Base case: all inputs matched; compare the kernel's result type with
// the iterator's output dtype. void-returning kernels never need a cast.
template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
  static bool check(TensorIteratorBase& iter) {
    using traits = function_traits<func_t>;
    using cpp_type = typename traits::result_type;

    // we could assert output numbers are correct here, but checks
    // (including arity) are currently pushed outside of this struct.
    if constexpr (std::is_void_v<cpp_type>) {
      return false;
    } else {
      return iter.dtype(0) != c10::CppTypeToScalarType<cpp_type>::value;
    }
  }
};

} //namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorShape.h ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
#include <ATen/core/IListRef.h>

namespace at::native {

// Presumably clones `self` while keeping its exact strides (a plain
// clone() would produce a contiguous result) — defined elsewhere.
TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);

// Legacy cat behavior: 1-D tensors with zero elements (the historical
// "empty tensor" placeholder, e.g. torch.tensor([])) are skipped.
inline bool cat_should_skip_tensor(const Tensor& t) {
  return t.sym_numel() == 0 && t.dim() == 1;
}
13
+
14
// Check to see if the shape of tensors is compatible
// for being concatenated along a given dimension.
// `index` is only used to name the offending tensor in the error message.
inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
  int64_t first_dims = first.dim();
  int64_t second_dims = second.dim();
  TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got ",
              first_dims, " and ", second_dims);
  // Every dimension except the concat dimension must match exactly.
  for (const auto dim : c10::irange(first_dims)) {
    if (dim == dimension) {
      continue;
    }
    int64_t first_dim_size = first.sizes()[dim];
    int64_t second_dim_size = second.sizes()[dim];
    TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
                dimension, ". Expected size ", static_cast<long long>(first_dim_size), " but got size ", static_cast<long long>(second_dim_size), " for tensor number ", index, " in the list.");
  }
}
31
+
32
// Rejects any zero-dimensional (scalar) tensor in a cat input list,
// reporting the position of the offender.
inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) {
  int64_t i = 0;
  for(const Tensor& t : tensors) {
    TORCH_CHECK(t.dim() > 0,
             "zero-dimensional tensor (at position ", i, ") cannot be concatenated");
    i++;
  }
}
40
+
41
// Number of chunks that `split` produces along `dim` for a given
// `split_size`: ceil(dim_size / split_size), floored at one split.
inline int64_t get_num_splits(const Tensor& self, int64_t split_size, int64_t dim) {
  TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
  TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
  int64_t dim_size = self.size(dim);
  TORCH_CHECK(split_size > 0 || dim_size == 0,
           "split_size can only be 0 if dimension size is 0, "
           "but got dimension size of ", dim_size);
  // if split_size is 0 and dimension size is 0, there is 1 split.
  int64_t num_splits = 1;
  if (split_size != 0) {
    // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
    // (returns a single split).  We might want to error here, but keep it for BC.
    num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
  }
  return num_splits;
}
57
+
58
+ inline bool have_same_ndims(TensorList tensors) {
59
+ auto ndim = tensors[0].dim();
60
+ for (const auto tensor_idx : c10::irange(tensors.size())) {
61
+ if(tensors[tensor_idx].dim() != ndim) {
62
+ return false;
63
+ }
64
+ }
65
+ return true;
66
+ }
67
+
68
+ inline void leading_dimension_matches(TensorList tensors, int64_t dim) {
69
+ auto tensor_zero_size = tensors[0].sizes();
70
+ std::vector<c10::SymInt> leading_dim_sizes(tensor_zero_size.begin(), tensor_zero_size.begin() + dim);
71
+ for (const auto i : c10::irange(tensors.size())) {
72
+ at::Tensor tensor = tensors[i];
73
+ for(const auto j : c10::irange(dim)) {
74
+ TORCH_CHECK(
75
+ tensor.size(j) == leading_dim_sizes[j],
76
+ "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"
77
+ );
78
+ }
79
+ }
80
+ }
81
+
82
// Validates _chunk_cat inputs — non-empty list, non-empty tensors,
// uniform dtype and device, matching leading dims — and returns the
// (possibly wrapped) concatenation dim.
inline int64_t preprocess_chunk_cat_inputs(TensorList tensors, int64_t dim, int64_t num_chunks) {
  TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks");
  TORCH_CHECK(!tensors.empty(),
           "_chunk_cat expects a non-empty input tensor list");
  auto expected_dtype = tensors[0].dtype();
  auto expected_device = tensors[0].device();
  for(const auto i : c10::irange(tensors.size())) {
    TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor");
    TORCH_CHECK(tensors[i].dtype() == expected_dtype, "_chunk_cat expects all input tensors with the same dtype");
    TORCH_CHECK(tensors[i].device() == expected_device, "_chunk_cat expects all inputs tensors on the same device");
  }
  // A negative dim is only unambiguous when all inputs share a rank.
  if (have_same_ndims(tensors)) {
    dim = maybe_wrap_dim(dim, tensors[0].dim());
  } else {
    TORCH_CHECK(dim >= 0, "_chunk_cat expects non-negative dim when input tensors have different ndims")
    for(const auto i : c10::irange(tensors.size())) {
      TORCH_CHECK(dim < tensors[i].ndimension(), "_chunk_cat expects dim < ndim for all input tensors");
    }
  }
  leading_dimension_matches(tensors, dim);
  return dim;
}

} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TriangularOpsUtils.h ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+ #include <ATen/native/LinearAlgebraUtils.h>
3
+
4
+ namespace at::native {
5
+
6
+ /*
7
+ * Given batches of matrices with arbitrary batch dim,
8
+ * computes the number of batches for Triu and Tril. This ignores stride 0 dimension
9
+ */
10
+ static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) {
11
+ int64_t result = 1;
12
+ for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
13
+ if (batched_matrices.stride(i) != 0) {
14
+ result *= batched_matrices.size(i);
15
+ }
16
+ }
17
+ return result;
18
+ }
19
+
20
/* Checks a necessary property for the triu and tril implementations, hence the name.
 * Here batch contiguity is checked for tensors with greater than 4 dimensions.
 * Contiguous tensors and tensors with less than 3 dimensions pass this check
 */
// Returns {true, tensor} when the batch layout is already usable by the
// tril/triu kernels; otherwise {false, fixed} where `fixed` is a
// restrided view or a contiguous copy the caller should use instead.
static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) {
  // Complete contiguity is the most desired property, which is why
  // we return true if the tensor is contiguous
  if (tensor.is_contiguous()) {
    auto default_strides_for_size = batched_matrix_contiguous_strides(tensor.sizes());
    if (tensor.strides() == default_strides_for_size) {
      return std::make_tuple(true, tensor);
    } else {
      // Contiguous but with non-canonical strides (e.g. from size-1
      // dims): reinterpret with the canonical batched strides.
      return std::make_tuple(false, tensor.as_strided(tensor.sizes(), default_strides_for_size));
    }
  }

  int64_t dims = tensor.dim();

  // Tensors with dimension less than 4 are handled by default
  if (allow_zero_stride && dims <= 3) {
    return std::make_tuple(true, tensor);
  }

  // Walk batch dims from innermost outward; each stride must equal the
  // product of all inner sizes (i.e. batches are densely packed).
  int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
  for (int64_t i = dims - 3; i >= 0; i--) {
    // Skip trivial dimension;
    if (allow_zero_stride && i == 0 && (tensor.stride(i) == 0 || tensor.size(i) == 1)) {
      continue;
    }
    if (expected_stride != tensor.stride(i)) {
      return std::make_tuple(false, tensor.contiguous());
    }
    expected_stride *= tensor.size(i);
  }
  return std::make_tuple(true, tensor);
}

} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnaryOps.h ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once

#include <ATen/native/DispatchStub.h>
#include <ATen/Generator.h>
#include <c10/core/Scalar.h>
#include <stdexcept>

namespace at {
class Tensor;
class TensorBase;
struct TensorIteratorBase;
}

namespace at::native {

// Common signatures for elementwise unary kernels driven by a
// TensorIterator; backends register implementations for the stubs below.
using unary_fn = void(*)(TensorIteratorBase&);
using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a);

// Kernels that are also callable directly (not only via dispatch) from
// code compiled for the current CPU capability.
inline namespace CPU_CAPABILITY {
void conj_kernel(TensorIteratorBase &iter);
void neg_kernel(TensorIteratorBase &iter);
void reciprocal_kernel(TensorIteratorBase &iter);
void rsqrt_kernel(TensorIteratorBase& iter);
void sqrt_kernel(TensorIteratorBase& iter);
} // namespace CPU_CAPABILITY

DECLARE_DISPATCH(unary_fn, abs_stub);
DECLARE_DISPATCH(unary_fn, angle_stub);
DECLARE_DISPATCH(unary_fn, conj_physical_stub);
DECLARE_DISPATCH(unary_fn, acos_stub);
DECLARE_DISPATCH(unary_fn, acosh_stub);
DECLARE_DISPATCH(unary_fn, asinh_stub);
DECLARE_DISPATCH(unary_fn, atanh_stub);
DECLARE_DISPATCH(unary_fn, asin_stub);
DECLARE_DISPATCH(unary_fn, atan_stub);
DECLARE_DISPATCH(unary_fn, bitwise_not_stub);
DECLARE_DISPATCH(unary_fn, logical_not_stub);
DECLARE_DISPATCH(unary_fn, ceil_stub);
DECLARE_DISPATCH(unary_fn, cos_stub);
DECLARE_DISPATCH(unary_fn, cosh_stub);
DECLARE_DISPATCH(unary_fn, digamma_stub);
DECLARE_DISPATCH(unary_fn, special_entr_stub);
DECLARE_DISPATCH(unary_fn, special_erfcx_stub);
DECLARE_DISPATCH(unary_fn, erf_stub);
DECLARE_DISPATCH(unary_fn, erfc_stub);
DECLARE_DISPATCH(unary_fn, erfinv_stub);
DECLARE_DISPATCH(unary_fn, exp_stub);
DECLARE_DISPATCH(unary_fn, exp2_stub);
DECLARE_DISPATCH(unary_fn, expm1_stub);
DECLARE_DISPATCH(unary_fn, floor_stub);
DECLARE_DISPATCH(unary_fn, frac_stub);
DECLARE_DISPATCH(unary_fn, frexp_stub);
DECLARE_DISPATCH(unary_fn, i0_stub);
DECLARE_DISPATCH(unary_fn, special_i0e_stub);
DECLARE_DISPATCH(unary_fn, special_i1_stub);
DECLARE_DISPATCH(unary_fn, special_i1e_stub);
DECLARE_DISPATCH(unary_fn, log_stub);
DECLARE_DISPATCH(unary_fn, log10_stub);
DECLARE_DISPATCH(unary_fn, log1p_stub);
DECLARE_DISPATCH(unary_fn, log2_stub);
DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub);
DECLARE_DISPATCH(unary_fn, neg_stub);

DECLARE_DISPATCH(unary_fn, reciprocal_stub);
DECLARE_DISPATCH(unary_fn, round_stub);
DECLARE_DISPATCH(unary_fn, rsqrt_stub);
DECLARE_DISPATCH(unary_fn, sigmoid_stub);
DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub);
DECLARE_DISPATCH(unary_fn, sign_stub);
DECLARE_DISPATCH(unary_fn, signbit_stub);
DECLARE_DISPATCH(unary_fn, sgn_stub);
DECLARE_DISPATCH(unary_fn, sin_stub);
DECLARE_DISPATCH(unary_fn, sinc_stub);
DECLARE_DISPATCH(unary_fn, sinh_stub);
DECLARE_DISPATCH(unary_fn, sqrt_stub);
DECLARE_DISPATCH(unary_fn, tan_stub);
DECLARE_DISPATCH(unary_fn, tanh_stub);
DECLARE_DISPATCH(unary_fn, trigamma_stub);
DECLARE_DISPATCH(unary_fn, trunc_stub);
DECLARE_DISPATCH(unary_fn, lgamma_stub);
// Bessel and related special functions (torch.special.*).
DECLARE_DISPATCH(unary_fn, special_airy_ai_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub);
DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub);
DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub);
DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub);

// NB: these are actually defined in Distribution
DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, std::optional<Generator>), bernoulli_tensor_stub);
DECLARE_DISPATCH(void(*)(const TensorBase&, const double, std::optional<Generator>), bernoulli_scalar_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), cauchy_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), exponential_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), geometric_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), log_normal_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), uniform_stub);
DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, std::optional<Generator>), normal_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, std::optional<Generator>), random_from_to_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_full_64_bits_range_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_stub);

// Unary-like ops that take extra scalar configuration.
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const Scalar& a, const Scalar& b), clamp_stub);
DECLARE_DISPATCH(
    void (*)(Tensor&, const Tensor&, int64_t, std::optional<Generator>),
    multinomial_with_replacement_stub);
DECLARE_DISPATCH(
    void (*)(
        TensorIteratorBase&,
        std::optional<double>,
        std::optional<double>,
        std::optional<double>),
    nan_to_num_stub);
DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub);

// Missing unary functions
// digamma
// lgamma
// erfinv
// clone
// contiguous
// zero
} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold2d.h ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <c10/core/ScalarType.h>
5
+ #include <cstdint>
6
+
7
+ namespace at::native {
8
+
9
+ using unfold2d_copy_fn = void (*)(
10
+ ScalarType dtype,
11
+ void *finput,
12
+ const void *input,
13
+ int64_t kH,
14
+ int64_t kW,
15
+ int64_t dH,
16
+ int64_t dW,
17
+ int64_t padH,
18
+ int64_t padW,
19
+ int64_t n_input_plane,
20
+ int64_t input_height,
21
+ int64_t input_width,
22
+ int64_t output_height,
23
+ int64_t output_width,
24
+ bool is_channels_last
25
+ );
26
+
27
+ using unfold2d_acc_fn = void (*)(
28
+ ScalarType dtype,
29
+ void *finput,
30
+ void *input,
31
+ int64_t kH,
32
+ int64_t kW,
33
+ int64_t dH,
34
+ int64_t dW,
35
+ int64_t padH,
36
+ int64_t padW,
37
+ int64_t n_input_plane,
38
+ int64_t input_height,
39
+ int64_t input_width,
40
+ int64_t output_height,
41
+ int64_t output_width,
42
+ bool is_channels_last
43
+ );
44
+
45
+ DECLARE_DISPATCH(unfold2d_copy_fn, unfolded2d_copy_stub);
46
+ DECLARE_DISPATCH(unfold2d_acc_fn, unfolded2d_acc_stub);
47
+
48
+ } // namespace at::native